[2nd Order AD] Regularization term in loss function #449

Open

avik-pal opened this issue Jan 1, 2025 · 7 comments

@avik-pal (Collaborator) commented Jan 1, 2025
xref https://discourse.julialang.org/t/second-order-gradient-with-lux-zygote-cuda-enzyme/124301
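The MWE below adds an MSE penalty on the input-gradient of the cross-entropy loss,

    loss(ps) = CE(model(x, ps), y) + MSE(∇ₓ CE(model(x, ps), y), δ)

so the parameter gradient ∇_ps loss contains the mixed second derivative ∂²CE/∂ps∂x, i.e. the trace has to differentiate through a reverse-mode gradient.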

using Lux, Random, OneHotArrays
using Reactant, Enzyme

model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

dev = reactant_device(; force=true)

ps, st = Lux.setup(Random.default_rng(), model) |> dev;

x = randn(Float32, 28, 28, 1, 32) |> dev;
δ = randn(Float32, 28, 28, 1, 32) |> dev;
y = onehotbatch(rand((1, 2), 32), 1:2) |> dev;

const celoss = CrossEntropyLoss(; logits=true)
const regloss = MSELoss()

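# First-order objective: cross-entropy on the raw logits.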
function loss_function(model, ps, st, x, y)
    pred, _ = model(x, ps, st)
    return celoss(pred, y)
end

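# Regularized objective: MSE-penalize the input-gradient of the CE loss toward δ.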
function ∂xloss_function(model, ps, st, x, δ, y)
    ∂x = Enzyme.gradient(
        Reverse, loss_function, Const(model), Const(ps), Const(st), x, Const(y))[4]
    return regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

@code_hlo ∂xloss_function(model, ps, st, x, δ, y)

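# Second-order: gradient of the regularized loss w.r.t. the parameters (reverse-over-reverse).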
function ∂∂xloss_function(model, ps, st, x, δ, y)
    return Enzyme.gradient(
        Reverse, ∂xloss_function, Const(model), ps, Const(st), Const(x), Const(δ), Const(y)
    )[2]
end

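# Tracing with optimize=false succeeds (IR below); the optimized trace errors out.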
@code_hlo optimize=false ∂∂xloss_function(model, ps, st, x, δ, y)

@code_hlo ∂∂xloss_function(model, ps, st, x, δ, y)

# error: could not compute the adjoint for this operation "enzyme.push"(%56, %125) : (!enzyme.Cache<tensor<32xf32>>, tensor<32xf32>) -> ()
# loc("subtract"("/mnt/software/lux/Reactant.jl/src/Ops.jl":193:0)): error: could not compute the adjoint for this operation "enzyme.push"(%51, %123) : (!enzyme.Cache<tensor<2x32xf32>>, tensor<2x32xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%42, %arg8) : (!enzyme.Cache<tensor<84x2xf32>>, tensor<84x2xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%34, %arg6) : (!enzyme.Cache<tensor<128x84xf32>>, tensor<128x84xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%26, %arg4) : (!enzyme.Cache<tensor<256x128xf32>>, tensor<256x128xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%19, %104) : (!enzyme.Cache<tensor<8x8x16x32xf32>>, tensor<8x8x16x32xf32>) -> ()
# loc("reverse"("/mnt/software/lux/Reactant.jl/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%11, %99) : (!enzyme.Cache<tensor<5x5x6x16xf32>>, tensor<5x5x6x16xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%8, %97) : (!enzyme.Cache<tensor<24x24x6x32xf32>>, tensor<24x24x6x32xf32>) -> ()
# loc("reverse"("/mnt/software/lux/Reactant.jl/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%0, %92) : (!enzyme.Cache<tensor<5x5x1x6xf32>>, tensor<5x5x1x6xf32>) -> ()
Unoptimized IR

module {
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @exp_fast_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.exponential %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @log_fast_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.log %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"-_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @"*_broadcast_scalar"(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i1>, tensor<f32>) {
    %0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1, %arg0, %arg1 : tensor<f32>, tensor<i1>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar1(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar2"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.negate %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %12 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
    %13 = stablehlo.convolution(%10, %12) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<28x28x1x32xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %14 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %15 = stablehlo.reshape %14 : (tensor<6xf32>) -> tensor<1x6x1x1xf32>
    %16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<1x6x1x1xf32>) -> tensor<1x1x6x1xf32>
    %17 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2, 3] : (tensor<1x1x6x1xf32>) -> tensor<24x24x6x32xf32>
    %18:3 = enzyme.batch @"+_broadcast_scalar"(%13, %17) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %19:2 = enzyme.batch @relu_broadcast_scalar(%18#0) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_2 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %20 = "stablehlo.reduce_window"(%19#0, %cst_2) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %86 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %86 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %21 = stablehlo.reverse %2, dims = [0, 1] : tensor<5x5x6x16xf32>
    %22 = stablehlo.convolution(%20, %21) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %23 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %24 = stablehlo.reshape %23 : (tensor<16xf32>) -> tensor<1x16x1x1xf32>
    %25 = stablehlo.transpose %24, dims = [3, 2, 1, 0] : (tensor<1x16x1x1xf32>) -> tensor<1x1x16x1xf32>
    %26 = stablehlo.broadcast_in_dim %25, dims = [0, 1, 2, 3] : (tensor<1x1x16x1xf32>) -> tensor<8x8x16x32xf32>
    %27:3 = enzyme.batch @"+_broadcast_scalar1"(%22, %26) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %28:2 = enzyme.batch @relu_broadcast_scalar1(%27#0) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_6 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %29 = "stablehlo.reduce_window"(%28#0, %cst_6) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %86 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %86 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %30 = stablehlo.transpose %29, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %31 = stablehlo.reshape %30 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<32x256xf32>) -> tensor<256x32xf32>
    %33 = stablehlo.dot_general %4, %32, contracting_dims = [1] x [0] : (tensor<128x256xf32>, tensor<256x32xf32>) -> tensor<128x32xf32>
    %34 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %35 = stablehlo.reshape %34 : (tensor<128xf32>) -> tensor<1x128xf32>
    %36 = stablehlo.transpose %35, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x32xf32>
    %38:3 = enzyme.batch @"+_broadcast_scalar2"(%33, %37) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>, tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>, tensor<128x32xf32>)
    %39:2 = enzyme.batch @relu_broadcast_scalar2(%38#0) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %40 = stablehlo.dot_general %6, %39#0, contracting_dims = [1] x [0] : (tensor<84x128xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %41 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %42 = stablehlo.reshape %41 : (tensor<84xf32>) -> tensor<1x84xf32>
    %43 = stablehlo.transpose %42, dims = [1, 0] : (tensor<1x84xf32>) -> tensor<84x1xf32>
    %44 = stablehlo.broadcast_in_dim %43, dims = [0, 1] : (tensor<84x1xf32>) -> tensor<84x32xf32>
    %45:3 = enzyme.batch @"+_broadcast_scalar3"(%40, %44) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>, tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>, tensor<84x32xf32>)
    %46:2 = enzyme.batch @relu_broadcast_scalar3(%45#0) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>)
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %47 = stablehlo.dot_general %8, %46#0, contracting_dims = [1] x [0] : (tensor<2x84xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %48 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %49:3 = enzyme.batch @"+_broadcast_scalar4"(%47, %48) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_11 = stablehlo.constant dense<1.1920929E-7> : tensor<f32>
    %cst_12 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_15 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %50 = enzyme.batch @identity_broadcast_scalar(%49#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %51 = stablehlo.reduce(%50 init: %cst_15) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %52 = stablehlo.transpose %51, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %53 = stablehlo.reshape %52 : (tensor<32xf32>) -> tensor<32x1xf32>
    %54 = stablehlo.transpose %53, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %55 = stablehlo.broadcast_in_dim %54, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %56:3 = enzyme.batch @"-_broadcast_scalar"(%49#0, %55) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %57:2 = enzyme.batch @exp_fast_broadcast_scalar(%56#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>)
    %58 = stablehlo.reduce(%57#0 init: %cst_16) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %59 = stablehlo.transpose %58, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %60 = stablehlo.reshape %59 : (tensor<32xf32>) -> tensor<32x1xf32>
    %61 = stablehlo.transpose %60, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %62:2 = enzyme.batch @log_fast_broadcast_scalar(%61) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %63 = stablehlo.broadcast_in_dim %62#0, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %64:3 = enzyme.batch @"-_broadcast_scalar1"(%57#1, %63) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %65:3 = enzyme.batch @"*_broadcast_scalar"(%11, %64#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xi1>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xi1>, tensor<2x32xf32>)
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %66 = enzyme.batch @identity_broadcast_scalar1(%65#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %67 = stablehlo.reduce(%66 init: %cst_17) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %68 = stablehlo.transpose %67, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %69 = stablehlo.reshape %68 : (tensor<32xf32>) -> tensor<32x1xf32>
    %70 = stablehlo.transpose %69, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %71:2 = enzyme.batch @"-_broadcast_scalar2"(%70) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %72 = stablehlo.reduce(%71#0 init: %cst_18) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_19 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %73 = stablehlo.divide %72, %cst_19 : tensor<f32>
    %74 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %75 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %76 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %77 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %78 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %79 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %80 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %81 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %82 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %83 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %84 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %85 = stablehlo.transpose %65#1, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func private @l2_distance_loss_broadcast_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    %1 = stablehlo.abs %0 : tensor<f32>
    %2 = stablehlo.multiply %1, %1 : tensor<f32>
    return %2, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar2(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar5(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar6(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar7(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar3(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @exp_fast_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.exponential %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @log_fast_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.log %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"-_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @"*_broadcast_scalar1"(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i1>, tensor<f32>) {
    %0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1, %arg0, %arg1 : tensor<f32>, tensor<i1>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar4(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar5"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.negate %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %12 = stablehlo.transpose %arg12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<28x28x1x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %13 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<28x28x1x32xf32>
    %cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %14 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %15 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %16 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %17 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %18 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %19 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %20 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %21 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %22 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %23 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %24 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %25 = stablehlo.transpose %12, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    %26 = stablehlo.transpose %13, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %27:13 = enzyme.autodiff @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %cst_2, %26) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<32x1x28x28xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<32x1x28x28xf32>)
    %28 = stablehlo.transpose %27#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %29 = stablehlo.transpose %27#1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %30 = stablehlo.transpose %27#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %31 = stablehlo.transpose %27#3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %32 = stablehlo.transpose %27#4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %33 = stablehlo.transpose %27#5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %34 = stablehlo.transpose %27#6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %35 = stablehlo.transpose %27#7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %36 = stablehlo.transpose %27#8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %37 = stablehlo.transpose %27#9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %38 = stablehlo.transpose %27#10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %39 = stablehlo.transpose %27#11, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %40 = stablehlo.transpose %27#12, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %41:3 = enzyme.batch @l2_distance_loss_broadcast_scalar(%40, %11) {batch_shape = array<i64: 28, 28, 1, 32>} : (tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>) -> (tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>)
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %42 = enzyme.batch @identity_broadcast_scalar2(%41#0) {batch_shape = array<i64: 28, 28, 1, 32>} : (tensor<28x28x1x32xf32>) -> tensor<28x28x1x32xf32>
    %43 = stablehlo.reduce(%42 init: %cst_3) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<28x28x1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_4 = stablehlo.constant dense<2.508800e+04> : tensor<f32>
    %44 = stablehlo.divide %43, %cst_4 : tensor<f32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %45 = stablehlo.reverse %28, dims = [0, 1] : tensor<5x5x1x6xf32>
    %46 = stablehlo.convolution(%38, %45) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<28x28x1x32xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %47 = stablehlo.transpose %29, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %48 = stablehlo.reshape %47 : (tensor<6xf32>) -> tensor<1x6x1x1xf32>
    %49 = stablehlo.transpose %48, dims = [3, 2, 1, 0] : (tensor<1x6x1x1xf32>) -> tensor<1x1x6x1xf32>
    %50 = stablehlo.broadcast_in_dim %49, dims = [0, 1, 2, 3] : (tensor<1x1x6x1xf32>) -> tensor<24x24x6x32xf32>
    %51:3 = enzyme.batch @"+_broadcast_scalar5"(%46, %50) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %52:2 = enzyme.batch @relu_broadcast_scalar4(%51#0) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_8 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %53 = "stablehlo.reduce_window"(%52#0, %cst_8) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %121 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %121 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %54 = stablehlo.reverse %30, dims = [0, 1] : tensor<5x5x6x16xf32>
    %55 = stablehlo.convolution(%53, %54) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %56 = stablehlo.transpose %31, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %57 = stablehlo.reshape %56 : (tensor<16xf32>) -> tensor<1x16x1x1xf32>
    %58 = stablehlo.transpose %57, dims = [3, 2, 1, 0] : (tensor<1x16x1x1xf32>) -> tensor<1x1x16x1xf32>
    %59 = stablehlo.broadcast_in_dim %58, dims = [0, 1, 2, 3] : (tensor<1x1x16x1xf32>) -> tensor<8x8x16x32xf32>
    %60:3 = enzyme.batch @"+_broadcast_scalar6"(%55, %59) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %61:2 = enzyme.batch @relu_broadcast_scalar5(%60#0) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_12 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %62 = "stablehlo.reduce_window"(%61#0, %cst_12) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %121 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %121 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %63 = stablehlo.transpose %62, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %64 = stablehlo.reshape %63 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<32x256xf32>) -> tensor<256x32xf32>
    %66 = stablehlo.dot_general %32, %65, contracting_dims = [1] x [0] : (tensor<128x256xf32>, tensor<256x32xf32>) -> tensor<128x32xf32>
    %67 = stablehlo.transpose %33, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %68 = stablehlo.reshape %67 : (tensor<128xf32>) -> tensor<1x128xf32>
    %69 = stablehlo.transpose %68, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %70 = stablehlo.broadcast_in_dim %69, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x32xf32>
    %71:3 = enzyme.batch @"+_broadcast_scalar7"(%66, %70) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>, tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>, tensor<128x32xf32>)
    %72:2 = enzyme.batch @relu_broadcast_scalar6(%71#0) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %73 = stablehlo.dot_general %34, %72#0, contracting_dims = [1] x [0] : (tensor<84x128xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %74 = stablehlo.transpose %35, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %75 = stablehlo.reshape %74 : (tensor<84xf32>) -> tensor<1x84xf32>
    %76 = stablehlo.transpose %75, dims = [1, 0] : (tensor<1x84xf32>) -> tensor<84x1xf32>
    %77 = stablehlo.broadcast_in_dim %76, dims = [0, 1] : (tensor<84x1xf32>) -> tensor<84x32xf32>
    %78:3 = enzyme.batch @"+_broadcast_scalar8"(%73, %77) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>, tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>, tensor<84x32xf32>)
    %79:2 = enzyme.batch @relu_broadcast_scalar7(%78#0) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>)
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %80 = stablehlo.dot_general %36, %79#0, contracting_dims = [1] x [0] : (tensor<2x84xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %81 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %82:3 = enzyme.batch @"+_broadcast_scalar9"(%80, %81) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_17 = stablehlo.constant dense<1.1920929E-7> : tensor<f32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_21 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %83 = enzyme.batch @identity_broadcast_scalar3(%82#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %84 = stablehlo.reduce(%83 init: %cst_21) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %85 = stablehlo.transpose %84, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %86 = stablehlo.reshape %85 : (tensor<32xf32>) -> tensor<32x1xf32>
    %87 = stablehlo.transpose %86, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %88 = stablehlo.broadcast_in_dim %87, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %89:3 = enzyme.batch @"-_broadcast_scalar3"(%82#0, %88) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %90:2 = enzyme.batch @exp_fast_broadcast_scalar1(%89#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>)
    %91 = stablehlo.reduce(%90#0 init: %cst_22) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %92 = stablehlo.transpose %91, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %93 = stablehlo.reshape %92 : (tensor<32xf32>) -> tensor<32x1xf32>
    %94 = stablehlo.transpose %93, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %95:2 = enzyme.batch @log_fast_broadcast_scalar1(%94) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %96 = stablehlo.broadcast_in_dim %95#0, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %97:3 = enzyme.batch @"-_broadcast_scalar4"(%90#1, %96) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %98:3 = enzyme.batch @"*_broadcast_scalar1"(%39, %97#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xi1>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xi1>, tensor<2x32xf32>)
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %99 = enzyme.batch @identity_broadcast_scalar4(%98#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %100 = stablehlo.reduce(%99 init: %cst_23) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %101 = stablehlo.transpose %100, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %102 = stablehlo.reshape %101 : (tensor<32xf32>) -> tensor<32x1xf32>
    %103 = stablehlo.transpose %102, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %104:2 = enzyme.batch @"-_broadcast_scalar5"(%103) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %105 = stablehlo.reduce(%104#0 init: %cst_24) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_25 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %106 = stablehlo.divide %105, %cst_25 : tensor<f32>
    %107 = stablehlo.add %44, %106 : tensor<f32>
    %108 = stablehlo.transpose %28, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %109 = stablehlo.transpose %29, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %110 = stablehlo.transpose %30, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %111 = stablehlo.transpose %31, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %112 = stablehlo.transpose %32, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %113 = stablehlo.transpose %33, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %114 = stablehlo.transpose %34, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %115 = stablehlo.transpose %35, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %116 = stablehlo.transpose %36, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %117 = stablehlo.transpose %37, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %118 = stablehlo.transpose %38, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %119 = stablehlo.transpose %41#2, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %120 = stablehlo.transpose %98#1, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func @main(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %12 = stablehlo.transpose %arg12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<5x5x1x6xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %13 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<5x5x1x6xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<6xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %14 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<6xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<5x5x6x16xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %15 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<5x5x6x16xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<16xf32>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %16 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_12 = stablehlo.constant dense<0.000000e+00> : tensor<128x256xf32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %17 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor<f32>) -> tensor<128x256xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %18 = stablehlo.broadcast_in_dim %cst_16, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<84x128xf32>
    %cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %19 = stablehlo.broadcast_in_dim %cst_19, dims = [] : (tensor<f32>) -> tensor<84x128xf32>
    %cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<84xf32>
    %cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %20 = stablehlo.broadcast_in_dim %cst_22, dims = [] : (tensor<f32>) -> tensor<84xf32>
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<2x84xf32>
    %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %21 = stablehlo.broadcast_in_dim %cst_25, dims = [] : (tensor<f32>) -> tensor<2x84xf32>
    %cst_26 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_27 = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
    %cst_28 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %22 = stablehlo.broadcast_in_dim %cst_28, dims = [] : (tensor<f32>) -> tensor<2xf32>
    %cst_29 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %23 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %24 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %25 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %26 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %27 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %28 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %29 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %30 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %31 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %32 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %33 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %34 = stablehlo.transpose %11, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %35 = stablehlo.transpose %12, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    %36 = stablehlo.transpose %13, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %37 = stablehlo.transpose %14, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %38 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %39 = stablehlo.transpose %16, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %40 = stablehlo.transpose %17, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %41 = stablehlo.transpose %18, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %42 = stablehlo.transpose %19, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %43 = stablehlo.transpose %20, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %44 = stablehlo.transpose %21, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %45 = stablehlo.transpose %22, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %46:23 = enzyme.autodiff @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %cst_29, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>)
    %47 = stablehlo.transpose %46#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %48 = stablehlo.transpose %46#1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %49 = stablehlo.transpose %46#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %50 = stablehlo.transpose %46#3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %51 = stablehlo.transpose %46#4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %52 = stablehlo.transpose %46#5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %53 = stablehlo.transpose %46#6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %54 = stablehlo.transpose %46#7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %55 = stablehlo.transpose %46#8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %56 = stablehlo.transpose %46#9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %57 = stablehlo.transpose %46#10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %58 = stablehlo.transpose %46#11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %59 = stablehlo.transpose %46#12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %60 = stablehlo.transpose %46#13, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %61 = stablehlo.transpose %46#14, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %62 = stablehlo.transpose %46#15, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %63 = stablehlo.transpose %46#16, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %64 = stablehlo.transpose %46#17, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %65 = stablehlo.transpose %46#18, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %66 = stablehlo.transpose %46#19, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %67 = stablehlo.transpose %46#20, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %68 = stablehlo.transpose %46#21, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %69 = stablehlo.transpose %46#22, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %70 = stablehlo.transpose %60, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %71 = stablehlo.transpose %61, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %72 = stablehlo.transpose %62, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %73 = stablehlo.transpose %63, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %74 = stablehlo.transpose %64, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %75 = stablehlo.transpose %65, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %76 = stablehlo.transpose %66, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %77 = stablehlo.transpose %67, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %78 = stablehlo.transpose %68, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %79 = stablehlo.transpose %69, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %80 = stablehlo.transpose %47, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %81 = stablehlo.transpose %48, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %82 = stablehlo.transpose %49, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %83 = stablehlo.transpose %50, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %84 = stablehlo.transpose %51, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %85 = stablehlo.transpose %52, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %86 = stablehlo.transpose %53, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %87 = stablehlo.transpose %54, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %88 = stablehlo.transpose %55, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %89 = stablehlo.transpose %56, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %90 = stablehlo.transpose %57, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %91 = stablehlo.transpose %58, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %92 = stablehlo.transpose %59, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92 : tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
}

@avik-pal
Collaborator Author

avik-pal commented Jan 1, 2025

A bit more minimal IR:

module {
  func.func private @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %cst_3 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %2 = stablehlo.convert %arg11 : (tensor<32x2xi1>) -> tensor<32x2xf32>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<32x2xf32>) -> tensor<2x32xf32>
    %4 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
    %5 = stablehlo.convolution(%arg10, %4) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<32x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %6 = stablehlo.broadcast_in_dim %arg1, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x32xf32>
    %7 = stablehlo.add %5, %6 : tensor<24x24x6x32xf32>
    %8 = stablehlo.compare  LT, %7, %cst_2 : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> tensor<24x24x6x32xi1>
    %9 = stablehlo.select %8, %cst_2, %7 : tensor<24x24x6x32xi1>, tensor<24x24x6x32xf32>
    %10 = "stablehlo.reduce_window"(%9, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %47 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %47 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %11 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
    %12 = stablehlo.convolution(%10, %11) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %13 = stablehlo.broadcast_in_dim %arg3, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x32xf32>
    %14 = stablehlo.add %12, %13 : tensor<8x8x16x32xf32>
    %15 = stablehlo.compare  LT, %14, %cst_1 : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> tensor<8x8x16x32xi1>
    %16 = stablehlo.select %15, %cst_1, %14 : tensor<8x8x16x32xi1>, tensor<8x8x16x32xf32>
    %17 = "stablehlo.reduce_window"(%16, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %47 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %47 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %18 = stablehlo.transpose %17, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %19 = stablehlo.reshape %18 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %20 = stablehlo.dot_general %arg4, %19, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<32x256xf32>) -> tensor<128x32xf32>
    %21 = stablehlo.broadcast_in_dim %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128x32xf32>
    %22 = stablehlo.add %20, %21 : tensor<128x32xf32>
    %23 = stablehlo.compare  LT, %22, %cst_0 : (tensor<128x32xf32>, tensor<128x32xf32>) -> tensor<128x32xi1>
    %24 = stablehlo.select %23, %cst_0, %22 : tensor<128x32xi1>, tensor<128x32xf32>
    %25 = stablehlo.dot_general %arg6, %24, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %26 = stablehlo.broadcast_in_dim %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84x32xf32>
    %27 = stablehlo.add %25, %26 : tensor<84x32xf32>
    %28 = stablehlo.compare  LT, %27, %cst : (tensor<84x32xf32>, tensor<84x32xf32>) -> tensor<84x32xi1>
    %29 = stablehlo.select %28, %cst, %27 : tensor<84x32xi1>, tensor<84x32xf32>
    %30 = stablehlo.dot_general %arg8, %29, contracting_dims = [0] x [0] : (tensor<84x2xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %31 = stablehlo.broadcast_in_dim %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %32 = stablehlo.add %30, %31 : tensor<2x32xf32>
    %33 = stablehlo.reduce(%32 init: %cst_5) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %34 = stablehlo.broadcast_in_dim %33, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
    %35 = stablehlo.subtract %32, %34 : tensor<2x32xf32>
    %36 = stablehlo.exponential %35 : tensor<2x32xf32>
    %37 = stablehlo.reduce(%36 init: %cst_4) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %38 = stablehlo.log %37 : tensor<32xf32>
    %39 = stablehlo.broadcast_in_dim %38, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
    %40 = stablehlo.subtract %35, %39 : tensor<2x32xf32>
    %41 = stablehlo.multiply %3, %40 : tensor<2x32xf32>
    %42 = stablehlo.reduce(%41 init: %cst_4) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %43 = stablehlo.negate %42 : tensor<32xf32>
    %44 = stablehlo.reshape %43 : (tensor<32xf32>) -> tensor<1x32xf32>
    %45 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %46 = stablehlo.divide %45, %cst_3 : tensor<f32>
    return %46, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func private @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<32x1x28x28xf32>
    %cst_4 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %cst_6 = stablehlo.constant dense<2.508800e+04> : tensor<f32>
    %cst_7 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %1:13 = enzyme.autodiff @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg12, %cst_7, %cst_3) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<32x1x28x28xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<32x1x28x28xf32>)
    %2 = stablehlo.transpose %1#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %3 = stablehlo.transpose %1#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %4 = stablehlo.convert %1#11 : (tensor<32x2xi1>) -> tensor<32x2xf32>
    %5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<32x2xf32>) -> tensor<2x32xf32>
    %6 = stablehlo.transpose %1#12, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %7 = stablehlo.subtract %6, %0 : tensor<28x28x1x32xf32>
    %8 = stablehlo.abs %7 : tensor<28x28x1x32xf32>
    %9 = stablehlo.multiply %8, %8 : tensor<28x28x1x32xf32>
    %10 = stablehlo.reduce(%9 init: %cst_8) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<28x28x1x32xf32>, tensor<f32>) -> tensor<f32>
    %11 = stablehlo.divide %10, %cst_6 : tensor<f32>
    %12 = stablehlo.reverse %2, dims = [0, 1] : tensor<5x5x1x6xf32>
    %13 = stablehlo.convolution(%1#10, %12) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<32x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %14 = stablehlo.broadcast_in_dim %1#1, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x32xf32>
    %15 = stablehlo.add %13, %14 : tensor<24x24x6x32xf32>
    %16 = stablehlo.compare  LT, %15, %cst_2 : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> tensor<24x24x6x32xi1>
    %17 = stablehlo.select %16, %cst_2, %15 : tensor<24x24x6x32xi1>, tensor<24x24x6x32xf32>
    %18 = "stablehlo.reduce_window"(%17, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %56 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %56 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %19 = stablehlo.reverse %3, dims = [0, 1] : tensor<5x5x6x16xf32>
    %20 = stablehlo.convolution(%18, %19) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %21 = stablehlo.broadcast_in_dim %1#3, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x32xf32>
    %22 = stablehlo.add %20, %21 : tensor<8x8x16x32xf32>
    %23 = stablehlo.compare  LT, %22, %cst_1 : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> tensor<8x8x16x32xi1>
    %24 = stablehlo.select %23, %cst_1, %22 : tensor<8x8x16x32xi1>, tensor<8x8x16x32xf32>
    %25 = "stablehlo.reduce_window"(%24, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %56 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %56 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %26 = stablehlo.transpose %25, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %27 = stablehlo.reshape %26 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %28 = stablehlo.dot_general %1#4, %27, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<32x256xf32>) -> tensor<128x32xf32>
    %29 = stablehlo.broadcast_in_dim %1#5, dims = [0] : (tensor<128xf32>) -> tensor<128x32xf32>
    %30 = stablehlo.add %28, %29 : tensor<128x32xf32>
    %31 = stablehlo.compare  LT, %30, %cst_0 : (tensor<128x32xf32>, tensor<128x32xf32>) -> tensor<128x32xi1>
    %32 = stablehlo.select %31, %cst_0, %30 : tensor<128x32xi1>, tensor<128x32xf32>
    %33 = stablehlo.dot_general %1#6, %32, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %34 = stablehlo.broadcast_in_dim %1#7, dims = [0] : (tensor<84xf32>) -> tensor<84x32xf32>
    %35 = stablehlo.add %33, %34 : tensor<84x32xf32>
    %36 = stablehlo.compare  LT, %35, %cst : (tensor<84x32xf32>, tensor<84x32xf32>) -> tensor<84x32xi1>
    %37 = stablehlo.select %36, %cst, %35 : tensor<84x32xi1>, tensor<84x32xf32>
    %38 = stablehlo.dot_general %1#8, %37, contracting_dims = [0] x [0] : (tensor<84x2xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %39 = stablehlo.broadcast_in_dim %1#9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %40 = stablehlo.add %38, %39 : tensor<2x32xf32>
    %41 = stablehlo.reduce(%40 init: %cst_5) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %42 = stablehlo.broadcast_in_dim %41, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
    %43 = stablehlo.subtract %40, %42 : tensor<2x32xf32>
    %44 = stablehlo.exponential %43 : tensor<2x32xf32>
    %45 = stablehlo.reduce(%44 init: %cst_8) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %46 = stablehlo.log %45 : tensor<32xf32>
    %47 = stablehlo.broadcast_in_dim %46, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
    %48 = stablehlo.subtract %43, %47 : tensor<2x32xf32>
    %49 = stablehlo.multiply %5, %48 : tensor<2x32xf32>
    %50 = stablehlo.reduce(%49 init: %cst_8) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %51 = stablehlo.negate %50 : tensor<32xf32>
    %52 = stablehlo.reshape %51 : (tensor<32xf32>) -> tensor<1x32xf32>
    %53 = stablehlo.reduce(%52 init: %cst_8) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %54 = stablehlo.divide %53, %cst_4 : tensor<f32>
    %55 = stablehlo.add %11, %54 : tensor<f32>
    return %55, %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7, %1#8, %1#9, %1#10, %arg11, %1#11 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func @main(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<6xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<16xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<84xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<6x1x5x5xf32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<16x6x5x5xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<256x128xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<128x84xf32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<84x2xf32>
    %0:23 = enzyme.autodiff @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %cst, %cst_5, %cst_0, %cst_6, %cst_1, %cst_7, %cst_2, %cst_8, %cst_3, %cst_9, %cst_4) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>)
    return %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12 : tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
}

@wsmoses
Member

wsmoses commented Jan 1, 2025

Yeah, this is just some infra we still need to do. Copying from Slack:

In essence we should:

1. Add an argument to the Enzyme pass that takes a string of optimization passes to run on newly generated functions right after creating them.
2. Modify the pass (like lowerkernel.cpp in EnzymeJaX does) so that it also emits the right registrations for those passes.
3. After EnzymeMLIR finishes a new function, run the passes in that string (examples of this code are in run_pass_pipeline in EnzymeJaX).
4. In Reactant.jl, pass the "default optimization set, including remove-enzyme-ops" pipeline as this string argument.
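A rough sketch of what steps 1 and 3 could look like, using MLIR's textual pass-pipeline utilities (illustrative only, not Enzyme's actual code; the helper name and the example pipeline string are made up):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// After EnzymeMLIR finishes generating a new function, run the user-supplied
// textual pass pipeline on just that function.
static mlir::LogicalResult
runPostAutodiffPasses(mlir::func::FuncOp newFunc, llvm::StringRef pipelineStr) {
  // Anchor the pass manager on func.func so the passes run per-function.
  mlir::PassManager pm(newFunc.getContext(),
                       mlir::func::FuncOp::getOperationName());
  // e.g. pipelineStr = "canonicalize,cse,remove-unnecessary-enzyme-ops"
  if (mlir::failed(mlir::parsePassPipeline(pipelineStr, pm)))
    return mlir::failure();
  return pm.run(newFunc);
}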

One other thing that would also fix this, and is probably worth doing separately anyway: adding autodiff interfaces to enzyme.push, enzyme.pop, etc., so the reverse pass can differentiate through the cache ops directly.

@wsmoses
Member

wsmoses commented Jan 1, 2025

xref EnzymeAD/Enzyme#2214

@wsmoses
Member

wsmoses commented Jan 1, 2025

Pending JLL (untested, but with support for post-passes): JuliaPackaging/Yggdrasil#10190

We'll then need to add the post-optimization flags to the Enzyme pass in our compile pipeline.

@wsmoses
Member

wsmoses commented Jan 2, 2025

Done in #452.

@wsmoses wsmoses closed this as completed Jan 2, 2025
@avik-pal
Collaborator Author

avik-pal commented Jan 2, 2025

We still need to implement the adjoint for stablehlo.select_and_scatter (the op MaxPool's reverse pass lowers to):

error: could not compute the adjoint for this operation 
%120 = "stablehlo.select_and_scatter"(%24, %119, %3) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg16: tensor<f32>, %arg17: tensor<f32>):
  %134 = "stablehlo.compare"(%arg16, %arg17) <{comparison_direction = #stablehlo<comparison_direction GE>}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
  "stablehlo.return"(%134) : (tensor<i1>) -> ()
}, {
^bb0(%arg14: tensor<f32>, %arg15: tensor<f32>):
  %133 = "stablehlo.add"(%arg14, %arg15) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "stablehlo.return"(%133) : (tensor<f32>) -> ()
}) : (tensor<24x24x6x32xf32>, tensor<12x12x6x32xf32>, tensor<f32>) -> tensor<24x24x6x32xf32>
error: could not compute the adjoint for this operation 
%107 = "stablehlo.select_and_scatter"(%31, %106, %3) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg20: tensor<f32>, %arg21: tensor<f32>):
  %136 = "stablehlo.compare"(%arg20, %arg21) <{comparison_direction = #stablehlo<comparison_direction GE>}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
  "stablehlo.return"(%136) : (tensor<i1>) -> ()
}, {
^bb0(%arg18: tensor<f32>, %arg19: tensor<f32>):
  %135 = "stablehlo.add"(%arg18, %arg19) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "stablehlo.return"(%135) : (tensor<f32>) -> ()
}) : (tensor<8x8x16x32xf32>, tensor<4x4x16x32xf32>, tensor<f32>) -> tensor<8x8x16x32xf32>
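For reference, here is a sketch of the missing rule (my own derivation, untested). Holding the primal operand fixed, select_and_scatter is linear in its source argument: writing $S(x)_w$ for the location selected in window $w$ (the per-window argmax, for max pooling), the op computes

$$y_j = \mathrm{init} + \sum_{w \,:\, S(x)_w = j} s_w,$$

so its transpose with respect to $s$ just gathers the cotangent back at the selected locations,

$$\bar{s}_w = \bar{y}_{S(x)_w},$$

i.e. a "select and gather" over the same windows (if I remember right, that is also how JAX transposes its internal select_and_scatter_add). The dependence on the operand only moves which location is selected, so it contributes zero derivative almost everywhere.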

@wsmoses wsmoses reopened this Jan 2, 2025
@wsmoses
Member

wsmoses commented Jan 2, 2025

Yeah, fair enough, but that does mean second-order AD itself now works.

@avik-pal do you want to give scatter and related ops a go?
