diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index 72451f837d..fb1e11f413 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
     (inp - target)?.sqr()?.mean_all()
 }
+
+/// The binary cross-entropy with logit loss.
+///
+/// Arguments
+///
+/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
+/// of categories. This is expected to contain the raw, unnormalized logits.
+/// * [target]: The ground truth labels as a float tensor with values in `[0, 1]`, of dimensions
+/// `N, C` where `N` is the batch size and `C` the number of categories.
+///
+/// The resulting tensor is a scalar containing the average value over the batch.
+pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let inp = crate::ops::sigmoid(inp)?;
+
+    let left_side = target * inp.log()?;
+    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+
+    let loss = left_side? + right_side?;
+    let loss = loss?.neg()?.mean_all()?;
+
+    Ok(loss)
+}
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index d772f1768c..ccfc029fdd 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
     assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     Ok(())
 }
+
+/* Equivalent python code:
+import torch
+import torch.nn.functional as F
+
+inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
+        [ 0.0419,  0.0763, -1.0457, -1.6692],
+        [-1.0494,  0.8111,  1.5723,  1.2315],
+        [ 1.3081,  0.6641,  1.1802, -0.2547],
+        [ 0.5292,  0.7636,  0.3692, -0.8318]])
+
+target = torch.Tensor([[0., 1., 0., 0.],
+        [0., 1., 0., 0.],
+        [0., 0., 0., 1.],
+        [1., 0., 0., 0.],
+        [0., 0., 1., 0.]])
+
+print(F.binary_cross_entropy_with_logits(inp, target))
+*/
+#[test]
+fn binary_cross_entropy_with_logit() -> Result<()> {
+    let cpu = Device::Cpu;
+
+    let inp = [
+        [2.3611f32, -0.8813, -0.5006, -0.2178],
+        [0.0419, 0.0763, -1.0457, -1.6692],
+        [-1.0494, 0.8111, 1.5723, 1.2315],
+        [1.3081, 0.6641, 1.1802, -0.2547],
+        [0.5292, 0.7636, 0.3692, -0.8318],
+    ];
+
+    let target = [
+        [0.0f32, 1., 0., 0.],
+        [0., 1., 0., 0.],
+        [0., 0., 0., 1.],
+        [1., 0., 0., 0.],
+        [0., 0., 1., 0.],
+    ];
+
+    let inp = Tensor::new(&inp, &cpu)?;
+    let target = Tensor::new(&target, &cpu)?;
+
+    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
+
+    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
+    Ok(())
+}