Skip to content

Commit

Permalink
Polish hand.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Oct 20, 2024
1 parent 074bbc1 commit 46feea0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 21 deletions.
12 changes: 7 additions & 5 deletions adbench/hand.fut
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,14 @@ entry calculate_jacobian [num_bones][N][M][num_us]
triangles,
is_mirrored }
let us_derivs = if num_us == 0 then 0 else 2
let f i =
let (theta',us') = onehot.onehot (onehot.(pair (arr f64) (arr f64))) i
let us'' = sized num_us (flatten (replicate (num_us/2) us') : [num_us/2*2]f64)
let oh : onehot.gen [theta_count*1+us_derivs]
([theta_count]f64,[num_us/2][2]f64) =
onehot.(pair (arr f64) (resize (cycle (arr f64))))
let f (theta',us') =
let us'' = sized num_us (flatten us')
in jvp (uncurry (objective model correspondences points))
(theta,us) (trace (theta',us''))
let J = map flatten (map f (iota (theta_count+us_derivs)))
(theta,us) (theta',us'')
let J = map flatten (map f (onehots oh))

in if num_us == 0
then J
Expand Down
21 changes: 21 additions & 0 deletions adbench/lib/github.com/diku-dk/autodiff/autodiff.fut
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- | Various utilities for performing AD.

import "onehot"

local def singular 'a (x: onehot.gen [1] a) = onehot.onehot x 0

-- | Compute the gradient of a scalar-valued function given a one-hot
-- generator for its result.
def grad_unit gen f x = vjp f x (singular gen)

-- | Convenience function for computing the gradient of an
-- 'f64'-valued differentiable function.
def grad32 = grad_unit onehot.f32

-- | Convenience function for computing the gradient of an
-- 'f64'-valued differentiable function.
def grad64 = grad_unit onehot.f64

-- | Compute the gradient of an arbitrary differentiable function
-- given a one-hot generator for its result.
def grad_rev gen f x = map (vjp f x) (onehots gen)
45 changes: 29 additions & 16 deletions adbench/lib/github.com/diku-dk/autodiff/onehot.fut
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ module type onehot = {
-- value of type 'a'.
val size [n] 'a : gen [n] a -> i64

-- | Resize the generation space. This does not affect the actual
-- values generated.
val resize [n][m] 'a : gen [n] a -> gen [m] a

-- | Produces a generator that is `one` at index 0 and `zero`
-- everywhere else.
val point 'a : (one: a) -> (zero: a) -> gen [1] a
Expand Down Expand Up @@ -64,6 +68,9 @@ module type onehot = {
-- Polymorphic in the array size, which will be inferred at the
-- usage site.
val arr [n][m] 'a : gen [m] a -> gen [n*m] ([n]a)

-- | Repeats the elements of a generator.
val cycle [n][r] 'a : gen [n] a -> gen [n] ([r]a)
}

module onehot : onehot = {
Expand All @@ -77,23 +84,10 @@ module onehot : onehot = {
def onehot 'a (gen: gen [] a) i = gen.gen i
def size [n] 'a (_: gen [n] a) = n

def resize [n][m] 'a (gen: gen [n] a) = {size = witness m, gen = gen.gen}

def point one zero = {size = witness 1,
gen = \i -> if i == 0i64 then one else zero}

def bool = point true false
def i8 = point 1i8 0i8
def i16 = point 1i16 0i16
def i32 = point 1i32 0i32
def i64 = point 1i64 0i64
def u8 = point 1u8 0u8
def u16 = point 1u16 0u16
def u32 = point 1u32 0u32
def u64 = point 1u64 0u64

def f16 = point 1f16 0f16
def f32 = point 1f32 0f32
def f64 = point 1f64 0f64

def fixed a = { size = witness 0, gen = const a }

def pair [n][m] 'a 'b (x: gen[n]a) (y: gen[m]b) =
Expand All @@ -112,9 +106,28 @@ module onehot : onehot = {
{ size = witness (n*m),
gen = \i -> tabulate n (\l ->
if i / m == l
then onehot gen (i % m)
then onehot gen (i %% m)
else onehot gen (-1))
}

def cycle [n][r] 'a (gen: gen [n] a) =
{ size = witness n,
gen = \i -> replicate r (gen.gen (if i < 0 then -1 else i%%n))
}

def bool = point true false
def i8 = point 1i8 0i8
def i16 = point 1i16 0i16
def i32 = point 1i32 0i32
def i64 = point 1i64 0i64
def u8 = point 1u8 0u8
def u16 = point 1u16 0u16
def u32 = point 1u32 0u32
def u64 = point 1u64 0u64

def f16 = point 1f16 0f16
def f32 = point 1f32 0f32
def f64 = point 1f64 0f64
}

-- | Generate all one-hot values possible for a given generator.
Expand Down
19 changes: 19 additions & 0 deletions adbench/lib/github.com/diku-dk/autodiff/onehot_tests.fut
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,22 @@ entry test_arr_pair n =
-- [0, 0, 0],
-- [0, 0, 1]]
-- }

-- ==
-- entry: test_cycle
-- input { 3i64 }
-- output {
-- [1.0, 0.0, 0.0]
-- [[[0, 0],
-- [0, 0],
-- [0, 0]],
-- [[1, 0],
-- [1, 0],
-- [1, 0]],
-- [[0, 1],
-- [0, 1],
-- [0, 1]]]
-- }

entry test_cycle n : ([]f64,[][n][2]i32) =
unzip (onehots (onehot.(pair f64 (cycle (arr i32)))))

0 comments on commit 46feea0

Please sign in to comment.