[2nd Order AD] Regularization term in loss function #449
A bit more minimal IR:

module {
func.func private @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
%cst_3 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
%1 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
%2 = stablehlo.convert %arg11 : (tensor<32x2xi1>) -> tensor<32x2xf32>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<32x2xf32>) -> tensor<2x32xf32>
%4 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
%5 = stablehlo.convolution(%arg10, %4) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<32x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
%6 = stablehlo.broadcast_in_dim %arg1, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x32xf32>
%7 = stablehlo.add %5, %6 : tensor<24x24x6x32xf32>
%8 = stablehlo.compare LT, %7, %cst_2 : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> tensor<24x24x6x32xi1>
%9 = stablehlo.select %8, %cst_2, %7 : tensor<24x24x6x32xi1>, tensor<24x24x6x32xf32>
%10 = "stablehlo.reduce_window"(%9, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
%47 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
stablehlo.return %47 : tensor<f32>
}) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
%11 = stablehlo.reverse %1, dims = [0, 1] : tensor<5x5x6x16xf32>
%12 = stablehlo.convolution(%10, %11) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
%13 = stablehlo.broadcast_in_dim %arg3, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x32xf32>
%14 = stablehlo.add %12, %13 : tensor<8x8x16x32xf32>
%15 = stablehlo.compare LT, %14, %cst_1 : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> tensor<8x8x16x32xi1>
%16 = stablehlo.select %15, %cst_1, %14 : tensor<8x8x16x32xi1>, tensor<8x8x16x32xf32>
%17 = "stablehlo.reduce_window"(%16, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
%47 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
stablehlo.return %47 : tensor<f32>
}) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
%18 = stablehlo.transpose %17, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
%19 = stablehlo.reshape %18 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
%20 = stablehlo.dot_general %arg4, %19, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<32x256xf32>) -> tensor<128x32xf32>
%21 = stablehlo.broadcast_in_dim %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128x32xf32>
%22 = stablehlo.add %20, %21 : tensor<128x32xf32>
%23 = stablehlo.compare LT, %22, %cst_0 : (tensor<128x32xf32>, tensor<128x32xf32>) -> tensor<128x32xi1>
%24 = stablehlo.select %23, %cst_0, %22 : tensor<128x32xi1>, tensor<128x32xf32>
%25 = stablehlo.dot_general %arg6, %24, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
%26 = stablehlo.broadcast_in_dim %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84x32xf32>
%27 = stablehlo.add %25, %26 : tensor<84x32xf32>
%28 = stablehlo.compare LT, %27, %cst : (tensor<84x32xf32>, tensor<84x32xf32>) -> tensor<84x32xi1>
%29 = stablehlo.select %28, %cst, %27 : tensor<84x32xi1>, tensor<84x32xf32>
%30 = stablehlo.dot_general %arg8, %29, contracting_dims = [0] x [0] : (tensor<84x2xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
%31 = stablehlo.broadcast_in_dim %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
%32 = stablehlo.add %30, %31 : tensor<2x32xf32>
%33 = stablehlo.reduce(%32 init: %cst_5) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%34 = stablehlo.broadcast_in_dim %33, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
%35 = stablehlo.subtract %32, %34 : tensor<2x32xf32>
%36 = stablehlo.exponential %35 : tensor<2x32xf32>
%37 = stablehlo.reduce(%36 init: %cst_4) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%38 = stablehlo.log %37 : tensor<32xf32>
%39 = stablehlo.broadcast_in_dim %38, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
%40 = stablehlo.subtract %35, %39 : tensor<2x32xf32>
%41 = stablehlo.multiply %3, %40 : tensor<2x32xf32>
%42 = stablehlo.reduce(%41 init: %cst_4) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%43 = stablehlo.negate %42 : tensor<32xf32>
%44 = stablehlo.reshape %43 : (tensor<32xf32>) -> tensor<1x32xf32>
%45 = stablehlo.reduce(%44 init: %cst_4) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
%46 = stablehlo.divide %45, %cst_3 : tensor<f32>
return %46, %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
}
func.func private @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<32x1x28x28xf32>
%cst_4 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
%cst_5 = stablehlo.constant dense<0xFF800000> : tensor<f32>
%cst_6 = stablehlo.constant dense<2.508800e+04> : tensor<f32>
%cst_7 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
%1:13 = enzyme.autodiff @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg12, %cst_7, %cst_3) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<32x1x28x28xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<32x1x28x28xf32>)
%2 = stablehlo.transpose %1#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
%3 = stablehlo.transpose %1#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
%4 = stablehlo.convert %1#11 : (tensor<32x2xi1>) -> tensor<32x2xf32>
%5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<32x2xf32>) -> tensor<2x32xf32>
%6 = stablehlo.transpose %1#12, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
%7 = stablehlo.subtract %6, %0 : tensor<28x28x1x32xf32>
%8 = stablehlo.abs %7 : tensor<28x28x1x32xf32>
%9 = stablehlo.multiply %8, %8 : tensor<28x28x1x32xf32>
%10 = stablehlo.reduce(%9 init: %cst_8) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<28x28x1x32xf32>, tensor<f32>) -> tensor<f32>
%11 = stablehlo.divide %10, %cst_6 : tensor<f32>
%12 = stablehlo.reverse %2, dims = [0, 1] : tensor<5x5x1x6xf32>
%13 = stablehlo.convolution(%1#10, %12) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<32x1x28x28xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
%14 = stablehlo.broadcast_in_dim %1#1, dims = [2] : (tensor<6xf32>) -> tensor<24x24x6x32xf32>
%15 = stablehlo.add %13, %14 : tensor<24x24x6x32xf32>
%16 = stablehlo.compare LT, %15, %cst_2 : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> tensor<24x24x6x32xi1>
%17 = stablehlo.select %16, %cst_2, %15 : tensor<24x24x6x32xi1>, tensor<24x24x6x32xf32>
%18 = "stablehlo.reduce_window"(%17, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
%56 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
stablehlo.return %56 : tensor<f32>
}) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
%19 = stablehlo.reverse %3, dims = [0, 1] : tensor<5x5x6x16xf32>
%20 = stablehlo.convolution(%18, %19) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
%21 = stablehlo.broadcast_in_dim %1#3, dims = [2] : (tensor<16xf32>) -> tensor<8x8x16x32xf32>
%22 = stablehlo.add %20, %21 : tensor<8x8x16x32xf32>
%23 = stablehlo.compare LT, %22, %cst_1 : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> tensor<8x8x16x32xi1>
%24 = stablehlo.select %23, %cst_1, %22 : tensor<8x8x16x32xi1>, tensor<8x8x16x32xf32>
%25 = "stablehlo.reduce_window"(%24, %cst_5) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
%56 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
stablehlo.return %56 : tensor<f32>
}) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
%26 = stablehlo.transpose %25, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
%27 = stablehlo.reshape %26 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
%28 = stablehlo.dot_general %1#4, %27, contracting_dims = [0] x [1] : (tensor<256x128xf32>, tensor<32x256xf32>) -> tensor<128x32xf32>
%29 = stablehlo.broadcast_in_dim %1#5, dims = [0] : (tensor<128xf32>) -> tensor<128x32xf32>
%30 = stablehlo.add %28, %29 : tensor<128x32xf32>
%31 = stablehlo.compare LT, %30, %cst_0 : (tensor<128x32xf32>, tensor<128x32xf32>) -> tensor<128x32xi1>
%32 = stablehlo.select %31, %cst_0, %30 : tensor<128x32xi1>, tensor<128x32xf32>
%33 = stablehlo.dot_general %1#6, %32, contracting_dims = [0] x [0] : (tensor<128x84xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
%34 = stablehlo.broadcast_in_dim %1#7, dims = [0] : (tensor<84xf32>) -> tensor<84x32xf32>
%35 = stablehlo.add %33, %34 : tensor<84x32xf32>
%36 = stablehlo.compare LT, %35, %cst : (tensor<84x32xf32>, tensor<84x32xf32>) -> tensor<84x32xi1>
%37 = stablehlo.select %36, %cst, %35 : tensor<84x32xi1>, tensor<84x32xf32>
%38 = stablehlo.dot_general %1#8, %37, contracting_dims = [0] x [0] : (tensor<84x2xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
%39 = stablehlo.broadcast_in_dim %1#9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
%40 = stablehlo.add %38, %39 : tensor<2x32xf32>
%41 = stablehlo.reduce(%40 init: %cst_5) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%42 = stablehlo.broadcast_in_dim %41, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
%43 = stablehlo.subtract %40, %42 : tensor<2x32xf32>
%44 = stablehlo.exponential %43 : tensor<2x32xf32>
%45 = stablehlo.reduce(%44 init: %cst_8) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%46 = stablehlo.log %45 : tensor<32xf32>
%47 = stablehlo.broadcast_in_dim %46, dims = [1] : (tensor<32xf32>) -> tensor<2x32xf32>
%48 = stablehlo.subtract %43, %47 : tensor<2x32xf32>
%49 = stablehlo.multiply %5, %48 : tensor<2x32xf32>
%50 = stablehlo.reduce(%49 init: %cst_8) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
%51 = stablehlo.negate %50 : tensor<32xf32>
%52 = stablehlo.reshape %51 : (tensor<32xf32>) -> tensor<1x32xf32>
%53 = stablehlo.reduce(%52 init: %cst_8) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
%54 = stablehlo.divide %53, %cst_4 : tensor<f32>
%55 = stablehlo.add %11, %54 : tensor<f32>
return %55, %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7, %1#8, %1#9, %1#10, %arg11, %1#11 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
}
func.func @main(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<6xf32>
%cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<16xf32>
%cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
%cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<84xf32>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<6x1x5x5xf32>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<16x6x5x5xf32>
%cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<256x128xf32>
%cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<128x84xf32>
%cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<84x2xf32>
%0:23 = enzyme.autodiff @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %cst, %cst_5, %cst_0, %cst_6, %cst_1, %cst_7, %cst_2, %cst_8, %cst_3, %cst_9, %cst_4) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>)
return %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12 : tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
}
}
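For orientation, here is a hedged Julia-level reconstruction of what the two private functions compute; every name below is read off the IR and hypothetical, not the original source. loss_function is a LeNet-style forward pass ending in a batch-mean softmax cross-entropy, ∂xloss_function adds the regularization term (the mean squared difference between ∂loss/∂x and a target, hence the division by 25088 = 28·28·32), and @main then differentiates ∂xloss_function with respect to the ten parameter tensors, i.e. the second-order step.

using Enzyme, Statistics
using NNlib: logsumexp

# Cross-entropy over a 2-class, batch-32 problem (ops %33-%46 above).
function loss_function(ps, x, y)                 # y :: 32×2 one-hot mask
    logits = lenet(ps, x)                        # conv→relu→maxpool ×2, then 3 dense layers
    logp   = logits .- logsumexp(logits; dims=1) # log-softmax, 2×32
    return -mean(sum(Float32.(y') .* logp; dims=1))
end

# Gradient-penalty loss: MSE between ∂loss/∂x and a target, plus the loss itself.
function ∂xloss_function(ps, x, ∂x_target, y)
    ∂x = Enzyme.gradient(Reverse, x′ -> loss_function(ps, x′, y), x)[1]
    return mean(abs2, ∂x .- ∂x_target) + loss_function(ps, x, y)
end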
Yeah, this is just some infra we still need to do. Copying from Slack, the plan is roughly (see the sketch below):
1. Add an argument to the enzyme pass that is a string of optimization passes to run on newly generated functions after creating them, modifying the pass like lowerkernel.cpp in EnzymeJaX so it also emits the right registrations for those new passes.
2. After EnzymeMLIR finishes a new function, run the passes in that string (examples of this code are in run_pass_pipeline in EnzymeJaX).
3. In reactant.jl, pass the "default optimization set including remove enzyme ops" as this string argument.

One other thing that would also fix this, and is separately probably worth doing: adding autodiff interfaces to enzyme push/pop etc.
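For concreteness, a minimal sketch of step 3 on the Julia side, assuming a hypothetical postpasses string option on the enzyme pass; the option name, the pipeline contents, and the run_pass_pipeline! helper are placeholders for whatever the pass ends up exposing:

# Hypothetical: the enzyme pass takes a pipeline string and runs it on each
# function it generates, so cleanups like remove-unnecessary-enzyme-ops
# happen eagerly instead of being left for a later whole-module pass.
post_passes = "canonicalize,remove-unnecessary-enzyme-ops"
run_pass_pipeline!(mod, "enzyme{postpasses=$post_passes},canonicalize")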
xref EnzymeAD/Enzyme#2214
Pending jll (untested, but with support for post passes): JuliaPackaging/Yggdrasil#10190. We'll then need to add the post-opt flags to the enzyme pass in our compile pipeline.
done #452
We still need to implement the adjoint for stablehlo.select_and_scatter; both max-pool layers hit it:

error: could not compute the adjoint for this operation
%120 = "stablehlo.select_and_scatter"(%24, %119, %3) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg16: tensor<f32>, %arg17: tensor<f32>):
%134 = "stablehlo.compare"(%arg16, %arg17) <{comparison_direction = #stablehlo<comparison_direction GE>}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%134) : (tensor<i1>) -> ()
}, {
^bb0(%arg14: tensor<f32>, %arg15: tensor<f32>):
%133 = "stablehlo.add"(%arg14, %arg15) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%133) : (tensor<f32>) -> ()
}) : (tensor<24x24x6x32xf32>, tensor<12x12x6x32xf32>, tensor<f32>) -> tensor<24x24x6x32xf32>
error: could not compute the adjoint for this operation
%107 = "stablehlo.select_and_scatter"(%31, %106, %3) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
^bb0(%arg20: tensor<f32>, %arg21: tensor<f32>):
%136 = "stablehlo.compare"(%arg20, %arg21) <{comparison_direction = #stablehlo<comparison_direction GE>}> : (tensor<f32>, tensor<f32>) -> tensor<i1>
"stablehlo.return"(%136) : (tensor<i1>) -> ()
}, {
^bb0(%arg18: tensor<f32>, %arg19: tensor<f32>):
%135 = "stablehlo.add"(%arg18, %arg19) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"stablehlo.return"(%135) : (tensor<f32>) -> ()
}) : (tensor<8x8x16x32xf32>, tensor<4x4x16x32xf32>, tensor<f32>) -> tensor<8x8x16x32xf32>
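For context, stablehlo.select_and_scatter is what the max-pool pullback lowers to, so this fires for any nested gradient through a MaxPool layer. A minimal reproducer sketch of that pattern (hypothetical and untested, assuming the Lux + Reactant + Enzyme setup from the linked Discourse thread):

using Lux, Reactant, Enzyme, Random

model  = MaxPool((2, 2))
ps, st = Lux.setup(Random.default_rng(), model)
x      = Reactant.to_rarray(rand(Float32, 28, 28, 1, 32))

# First order: the max-pool pullback lowers to stablehlo.select_and_scatter.
inner(x) = sum(abs2, first(model(x, ps, st)))
# Second order: differentiating that pullback needs the adjoint that is missing above.
outer(x) = sum(abs2, Enzyme.gradient(Reverse, inner, x)[1])

@jit Enzyme.gradient(Reverse, outer, x)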
Yeah, fair enough, but that means that second order now works. @avik-pal, do you want to give scatter and related ops a go?
xref https://discourse.julialang.org/t/second-order-gradient-with-lux-zygote-cuda-enzyme/124301