diff --git a/solidago/src/solidago/preference_learning/generalized_bradley_terry.py b/solidago/src/solidago/preference_learning/generalized_bradley_terry.py
index d737de0495..20c56c768f 100644
--- a/solidago/src/solidago/preference_learning/generalized_bradley_terry.py
+++ b/solidago/src/solidago/preference_learning/generalized_bradley_terry.py
@@ -50,7 +50,7 @@ def cumulant_generating_function_derivative(self) -> Callable[[npt.NDArray], npt
 
     @property
     @abstractmethod
-    def loss_function(self) -> Callable[[npt.NDArray, npt.NDArray, float], float]:
+    def log_likelihood_function(self) -> Callable[[npt.NDArray, npt.NDArray], float]:
         """The loss function definition is used only to compute uncertainties.
         """
 
@@ -59,12 +59,11 @@ def loss_increase_to_solve(self):
         """The root of this function is used to compute asymetric uncertainties
         by looking for the delta for which the loss is increased by 1.
         """
-        loss_function = self.loss_function
+        ll_function = self.log_likelihood_function
 
         @njit
-        def f(delta, theta_diff, r, coord_indicator, loss_actual, norm2_actual, solution_actual):
-            norm2 = norm2_actual - solution_actual**2 + (solution_actual + delta) ** 2
-            return loss_function(theta_diff + delta * coord_indicator, r, norm2) - loss_actual - 1.0
+        def f(delta, theta_diff, r, coord_indicator, ll_actual):
+            return ll_function(theta_diff + delta * coord_indicator, r) - ll_actual - 1.0
 
         return f
 
@@ -143,17 +142,16 @@ def get_derivative_args(coord: int, sol: np.ndarray):
 
         uncertainties_left = np.empty_like(solution)
         uncertainties_right = np.empty_like(solution)
-        norm2_actual = (solution ** 2).sum()
-        loss_actual = self.loss_function(score_diff, r_actual, norm2_actual)
+        ll_actual = self.log_likelihood_function(score_diff, r_actual)
 
-        for coordinate, solution_coord in enumerate(solution):
+        for coordinate in range(len(solution)):
             comparison_indicator = (
                 (comparisons["entity_a_coord"] == coordinate).astype(int)
                 - (comparisons["entity_b_coord"] == coordinate).astype(int)
             ).to_numpy()
             uncertainties_left[coordinate] = -1 * njit_brentq(
                 self.loss_increase_to_solve,
-                args=(score_diff, r_actual, comparison_indicator, loss_actual, norm2_actual, solution_coord),
+                args=(score_diff, r_actual, comparison_indicator, ll_actual),
                 xtol=self.convergence_error,
                 a=-1.0,
                 b=0.0,
@@ -161,7 +159,7 @@ def get_derivative_args(coord: int, sol: np.ndarray):
             )
             uncertainties_right[coordinate] = njit_brentq(
                 self.loss_increase_to_solve,
-                args=(score_diff, r_actual, comparison_indicator, loss_actual, norm2_actual, solution_coord),
+                args=(score_diff, r_actual, comparison_indicator, ll_actual),
                 xtol=self.convergence_error,
                 a=0.0,
                 b=1.0
@@ -256,16 +254,11 @@ def __init__(
         self.cumulant_generating_function_error = cumulant_generating_function_error
 
     @cached_property
-    def loss_function(self):
-        prior_std_dev = self.prior_std_dev
-
+    def log_likelihood_function(self):
         @njit
-        def loss(score_diff, r, norm2):
-            return (np.log(np.sinh(score_diff) / score_diff) + r * score_diff).sum() + norm2 / (
-                2 * prior_std_dev**2
-            )
-
-        return loss
+        def f(score_diff, r):
+            return (np.log(np.sinh(score_diff) / score_diff) + r * score_diff).sum()
+        return f
 
     @cached_property
     def cumulant_generating_function_derivative(self) -> Callable[[npt.NDArray], npt.NDArray]:
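
For context, here is a minimal standalone sketch (not part of the patch) of the uncertainty rule this diff switches to: starting from the converged scores, a Brent root search looks, on each side of zero, for the delta at which the quantity returned by log_likelihood_function has grown by 1. The sketch uses scipy.optimize.brentq in place of solidago's internal njit_brentq, handles a single coordinate, and the toy arrays, bracket widths, and helper names are illustrative assumptions rather than the project's actual data or API.

import numpy as np
from scipy.optimize import brentq

def log_likelihood_function(score_diff, r):
    # Same expression as the patched log_likelihood_function; the quadratic
    # penalty term norm2 / (2 * prior_std_dev**2) of the old loss is dropped.
    return (np.log(np.sinh(score_diff) / score_diff) + r * score_diff).sum()

# Toy data (assumed): per-comparison score differences at the optimum and
# normalized comparison values r.
score_diff = np.array([0.3, -0.7, 1.2])
r = np.array([-0.5, 0.8, -0.9])
ll_actual = log_likelihood_function(score_diff, r)

# +1 / -1 / 0 per comparison, depending on whether the perturbed coordinate
# appears as entity_a, entity_b, or neither (mirrors comparison_indicator).
coord_indicator = np.array([1.0, -1.0, 0.0])

def loss_increase_to_solve(delta):
    # Root of this function: the delta that raises the quantity above by 1.
    return log_likelihood_function(score_diff + delta * coord_indicator, r) - ll_actual - 1.0

# Asymmetric uncertainties: one root search on each side of the optimum.
# The [-10, 10] brackets are an assumption; the patch brackets with [-1, 1].
uncertainty_right = brentq(loss_increase_to_solve, a=0.0, b=10.0, xtol=1e-6)
uncertainty_left = -1 * brentq(loss_increase_to_solve, a=-10.0, b=0.0, xtol=1e-6)
print(uncertainty_left, uncertainty_right)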