Skip to content

Commit

Permalink
Adding airthmetic operations
Browse files Browse the repository at this point in the history
  • Loading branch information
pavanky committed Feb 27, 2017
1 parent 9cbe920 commit d101a1f
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 13 deletions.
1 change: 1 addition & 0 deletions arrayfire.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ require('arrayfire.defines')
require('arrayfire.dim4')
require('arrayfire.util')
require('arrayfire.array')
require('arrayfire.arith')
require('arrayfire.device')

return af
190 changes: 190 additions & 0 deletions arrayfire/arith.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
require('arrayfire.lib')
require('arrayfire.defines')
require('arrayfire.array')
local ffi = require( "ffi" )

local funcs = {}

funcs[30] = [[
af_err af_add (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_sub (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_mul (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_div (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_lt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_gt (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_le (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_ge (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_eq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_neq (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_and (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_or (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_not (af_array *out, const af_array in);
af_err af_bitand (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitxor (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitshiftl(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_bitshiftr(af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cast (af_array *out, const af_array in, const af_dtype type);
af_err af_minof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_maxof (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_rem (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_mod (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_abs (af_array *out, const af_array in);
af_err af_arg (af_array *out, const af_array in);
af_err af_sign (af_array *out, const af_array in);
af_err af_round (af_array *out, const af_array in);
af_err af_trunc (af_array *out, const af_array in);
af_err af_floor (af_array *out, const af_array in);
af_err af_ceil (af_array *out, const af_array in);
af_err af_hypot (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_sin (af_array *out, const af_array in);
af_err af_cos (af_array *out, const af_array in);
af_err af_tan (af_array *out, const af_array in);
af_err af_asin (af_array *out, const af_array in);
af_err af_acos (af_array *out, const af_array in);
af_err af_atan (af_array *out, const af_array in);
af_err af_atan2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cplx2 (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_cplx (af_array *out, const af_array in);
af_err af_real (af_array *out, const af_array in);
af_err af_imag (af_array *out, const af_array in);
af_err af_conjg (af_array *out, const af_array in);
af_err af_sinh (af_array *out, const af_array in);
af_err af_cosh (af_array *out, const af_array in);
af_err af_tanh (af_array *out, const af_array in);
af_err af_asinh (af_array *out, const af_array in);
af_err af_acosh (af_array *out, const af_array in);
af_err af_atanh (af_array *out, const af_array in);
af_err af_root (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_pow (af_array *out, const af_array lhs, const af_array rhs, const bool batch);
af_err af_pow2 (af_array *out, const af_array in);
af_err af_exp (af_array *out, const af_array in);
af_err af_expm1 (af_array *out, const af_array in);
af_err af_erf (af_array *out, const af_array in);
af_err af_erfc (af_array *out, const af_array in);
af_err af_log (af_array *out, const af_array in);
af_err af_log1p (af_array *out, const af_array in);
af_err af_log10 (af_array *out, const af_array in);
af_err af_log2 (af_array *out, const af_array in);
af_err af_sqrt (af_array *out, const af_array in);
af_err af_cbrt (af_array *out, const af_array in);
af_err af_factorial (af_array *out, const af_array in);
af_err af_tgamma (af_array *out, const af_array in);
af_err af_lgamma (af_array *out, const af_array in);
af_err af_iszero (af_array *out, const af_array in);
af_err af_isinf (af_array *out, const af_array in);
af_err af_isnan (af_array *out, const af_array in);
]]

funcs[31] = [[
af_err af_sigmoid (af_array *out, const af_array in);
]]

funcs[34] = [[
af_err af_clamp(af_array *out, const af_array in,
const af_array lo, const af_array hi, const bool batch);
]]

af.lib.cdef(funcs)
local c_array_p = af.ffi.c_array_p
local init = af.Array.init

local binaryFuncs = {
'add',
'sub',
'mul',
'div',
'lt',
'gt',
'le',
'ge',
'eq',
'neq',
'and',
'or',
'bitand',
'bitor',
'bitxor',
'bitshiftl',
'bitshiftr',
'minof',
'maxof',
'rem',
'mod',
'hypot',
'atan2',
'cplx2',
'root',
'pow',
}


for _, func in ipairs(binaryFuncs) do
af[func] = function(lhs, rhs, batch)
-- TODO: add support for numbers
-- TODO: add support for batch mode
local res = c_array_p()
af.clib['af_' .. func](res, lhs:get(), rhs:get(), batch and true or false)
return init(res[0])
end
end

local unaryFuncs = {
'abs',
'arg',
'sign',
'round',
'trunc',
'floor',
'ceil',
'sin',
'cos',
'tan',
'asin',
'acos',
'atan',
'cplx',
'real',
'imag',
'conjg',
'sinh',
'cosh',
'tanh',
'asinh',
'acosh',
'atanh',
'pow2',
'exp',
'expm1',
'erf',
'erfc',
'log',
'log1p',
'log10',
'log2',
'sqrt',
'cbrt',
'factorial',
'tgamma',
'lgamma',
'iszero',
'isinf',
'isnan'
}

for _, func in ipairs(unaryFuncs) do
af[func] = function(input)
-- TODO: add support for numbers
-- TODO: add support for batch mode
local res = c_array_p()
af.clib['af_' .. func](res, input:get())
return init(res[0])
end
end

af.cast = function(input, rtype)
local res = c_array_p()
af.clib.af_cast(res, input:get(), rtype)
return init(res[0])
end
28 changes: 16 additions & 12 deletions arrayfire/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,19 @@ local c_uint_t = af.ffi.c_uint_t
local c_ptr_t = af.ffi.c_ptr_t
local Dim4 = af.Dim4

function release_array(ptr)
local res = af.clib.af_release_array(ptr)
-- TODO: Error handling logic
end

local c_array_p = function(ptr)
local arr_ptr = ffi.new('void *[1]', ptr)
arr_ptr[0] = ffi.gc(arr_ptr[0], af.clib.af_release_array)
return arr_ptr
end

local init = function(ptr)
local self = setmetatable({}, Array)
self._array = ptr
self._ptr = ffi.gc(ptr, release_array)
return self
end

Expand Down Expand Up @@ -117,51 +121,51 @@ Array.__tostring = function(self)
end

Array.get = function(self)
return self._array
return self._ptr
end

-- TODO: implement Array.write

Array.copy = function(self)
local res = c_array_p()
af.clib.af_copy_array(res, self._array)
af.clib.af_copy_array(res, self:get())
return Array.init(res[0])
end

Array.softCopy = function(self)
local res = c_array_p()
af.clib.af_copy_array(res, self._array)
af.clib.af_copy_array(res, self:get())
return Array.init(res[0])
end

Array.elements = function(self)
local res = c_ptr_t('dim_t')
af.clib.af_get_elements(res, self._array)
af.clib.af_get_elements(res, self:get())
return tonumber(res[0])
end

Array.type = function(self)
local res = c_ptr_t('af_dtype')
af.clib.af_get_type(res, self._array)
af.clib.af_get_type(res, self:get())
return tonumber(res[0])
end

Array.typeName = function(self)
local res = c_ptr_t('af_dtype')
af.clib.af_get_type(res, self._array)
af.clib.af_get_type(res, self:get())
return af.dtype_names[tonumber(res[0])]
end

Array.dims = function(self)
local res = c_dim4_t()
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self._array)
af.clib.af_get_dims(res + 0, res + 1, res + 2, res + 3, self:get())
return Dim4(tonumber(res[0]), tonumber(res[1]),
tonumber(res[2]), tonumber(res[3]))
end

Array.numdims = function(self)
local res = c_ptr_t('unsigned int')
af.clib.af_get_numdims(res, self._array)
af.clib.af_get_numdims(res, self:get())
return tonumber(res[0])
end

Expand All @@ -184,13 +188,13 @@ local funcs = {
for name, cname in pairs(funcs) do
Array[name] = function(self)
local res = c_ptr_t('bool')
af.clib['af_' .. cname](res, self._array)
af.clib['af_' .. cname](res, self:get())
return res[0]
end
end

Array.eval = function(self)
af.clib.af_eval(self._array)
af.clib.af_eval(self:get())
end

-- Useful aliases
Expand Down
2 changes: 1 addition & 1 deletion arrayfire/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ funcs[34] = [[
af.lib.cdef(funcs)

af.print = function(arr)
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr._array, 4)
af.clib.af_print_array_gen(ffi.cast("char *", "ArrayFire Array"), arr:get(), 4)
end
1 change: 1 addition & 0 deletions rocks/arrayfire-scm-1.rockspec
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ build = {
["arrayfire.defines"] = "arrayfire/defines.lua",
["arrayfire.device"] = "arrayfire/device.lua",
["arrayfire.dim4"] = "arrayfire/dim4.lua",
["arrayfire.arith"] = "arrayfire/arith.lua",
},
}

0 comments on commit d101a1f

Please sign in to comment.