Skip to content

Commit

Permalink
gbt: use negative log likelihood term instead of loss to compute unce…
Browse files Browse the repository at this point in the history
…rtainties
  • Loading branch information
amatissart committed May 18, 2024
1 parent 18a2473 commit bf95cbb
Showing 1 changed file with 12 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def cumulant_generating_function_derivative(self) -> Callable[[npt.NDArray], npt

@property
@abstractmethod
def loss_function(self) -> Callable[[npt.NDArray, npt.NDArray, float], float]:
def log_likelihood_function(self) -> Callable[[npt.NDArray, npt.NDArray], float]:
"""The loss function definition is used only to compute uncertainties.
"""

Expand All @@ -59,12 +59,11 @@ def loss_increase_to_solve(self):
"""The root of this function is used to compute asymetric uncertainties
by looking for the delta for which the loss is increased by 1.
"""
loss_function = self.loss_function
ll_function = self.log_likelihood_function

@njit
def f(delta, theta_diff, r, coord_indicator, loss_actual, norm2_actual, solution_actual):
norm2 = norm2_actual - solution_actual**2 + (solution_actual + delta) ** 2
return loss_function(theta_diff + delta * coord_indicator, r, norm2) - loss_actual - 1.0
def f(delta, theta_diff, r, coord_indicator, ll_actual):
return ll_function(theta_diff + delta * coord_indicator, r) - ll_actual - 1.0

return f

Expand Down Expand Up @@ -143,25 +142,24 @@ def get_derivative_args(coord: int, sol: np.ndarray):

uncertainties_left = np.empty_like(solution)
uncertainties_right = np.empty_like(solution)
norm2_actual = (solution ** 2).sum()
loss_actual = self.loss_function(score_diff, r_actual, norm2_actual)
ll_actual = self.log_likelihood_function(score_diff, r_actual)

for coordinate, solution_coord in enumerate(solution):
for coordinate in range(len(solution)):
comparison_indicator = (
(comparisons["entity_a_coord"] == coordinate).astype(int)
- (comparisons["entity_b_coord"] == coordinate).astype(int)
).to_numpy()
uncertainties_left[coordinate] = -1 * njit_brentq(
self.loss_increase_to_solve,
args=(score_diff, r_actual, comparison_indicator, loss_actual, norm2_actual, solution_coord),
args=(score_diff, r_actual, comparison_indicator, ll_actual),
xtol=self.convergence_error,
a=-1.0,
b=0.0,
ascending=False,
)
uncertainties_right[coordinate] = njit_brentq(
self.loss_increase_to_solve,
args=(score_diff, r_actual, comparison_indicator, loss_actual, norm2_actual, solution_coord),
args=(score_diff, r_actual, comparison_indicator, ll_actual),
xtol=self.convergence_error,
a=0.0,
b=1.0
Expand Down Expand Up @@ -256,16 +254,11 @@ def __init__(
self.cumulant_generating_function_error = cumulant_generating_function_error

@cached_property
def loss_function(self):
prior_std_dev = self.prior_std_dev

def log_likelihood_function(self):
@njit
def loss(score_diff, r, norm2):
return (np.log(np.sinh(score_diff) / score_diff) + r * score_diff).sum() + norm2 / (
2 * prior_std_dev**2
)

return loss
def f(score_diff, r):
return (np.log(np.sinh(score_diff) / score_diff) + r * score_diff).sum()
return f

@cached_property
def cumulant_generating_function_derivative(self) -> Callable[[npt.NDArray], npt.NDArray]:
Expand Down

0 comments on commit bf95cbb

Please sign in to comment.