Add Binary Cross Entropy With Logit Loss to nn crate (#1157)
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting
ToluClassics authored Oct 23, 2023
1 parent 25c3cc4 commit 86e1803
Showing 2 changed files with 69 additions and 0 deletions.
22 changes: 22 additions & 0 deletions candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}

/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
///   of categories. This is expected to contain raw (pre-sigmoid) logits.
/// * [target]: The ground truth labels as a float tensor of dimensions `N, C` (the same shape as
///   `inp`) with values in `{0, 1}`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    // Map the raw logits onto probabilities p in (0, 1).
    let inp = crate::ops::sigmoid(inp)?;

    // y * log(p)
    let left_side = (target * inp.log()?)?;
    // (1 - y) * log(1 - p), using affine(-1., 1.) to compute 1 - t elementwise.
    let right_side = (target.affine(-1., 1.)? * inp.affine(-1., 1.)?.log()?)?;

    // Negate and average over all N * C elements.
    let loss = (left_side + right_side)?.neg()?.mean_all()?;

    Ok(loss)
}
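
For reference, the quantity computed above is the standard binary cross-entropy averaged over every element of the `N, C` batch. In the notation of the doc comment, with sigma the logistic sigmoid:

L(x, y) = -\frac{1}{NC} \sum_{n=1}^{N} \sum_{c=1}^{C} \Big[ y_{nc} \log \sigma(x_{nc}) + (1 - y_{nc}) \log\big(1 - \sigma(x_{nc})\big) \Big], \qquad \sigma(x) = \frac{1}{1 + e^{-x}}

This matches PyTorch's F.binary_cross_entropy_with_logits under its default mean reduction, which is exactly what the test added below checks against.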
47 changes: 47 additions & 0 deletions candle-nn/tests/loss.rs
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
Ok(())
}

/* Equivalent python code:
import torch
import torch.nn.functional as F
inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318]])
target = torch.Tensor([[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])
print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_logit() -> Result<()> {
let cpu = Device::Cpu;

let inp = [
[2.3611f32, -0.8813, -0.5006, -0.2178],
[0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[1.3081, 0.6641, 1.1802, -0.2547],
[0.5292, 0.7636, 0.3692, -0.8318],
];

let target = [
[0.0f32, 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
];

let inp = Tensor::new(&inp, &cpu)?;
let target = Tensor::new(&target, &cpu)?;

let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
Ok(())
}
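
A side note on the formulation: composing sigmoid with log, as the new function does, can underflow to log(0) = -inf when a logit has large magnitude. PyTorch's binary_cross_entropy_with_logits is documented to avoid this via the log-sum-exp identity max(x, 0) - x*y + ln(1 + exp(-|x|)). A minimal sketch of that stable variant is below; the function name is hypothetical (not part of this commit), and it assumes the candle-core tensor ops relu, abs, neg, exp, affine, log, and mean_all:

use candle_core::{Result, Tensor};

/// Hypothetical numerically stable variant: max(x, 0) - x*y + ln(1 + exp(-|x|)).
/// Not what this commit implements; shown only as an alternative formulation.
pub fn binary_cross_entropy_with_logit_stable(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    // max(x, 0) - x * y
    let linear = (inp.relu()? - (inp * target)?)?;
    // ln(1 + exp(-|x|)); the exponent is always <= 0, so exp never overflows.
    let log_term = inp.abs()?.neg()?.exp()?.affine(1., 1.)?.log()?;
    // Average over all elements, matching the mean reduction above.
    (linear + log_term)?.mean_all()
}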
