Skip to content

Commit

Permalink
simplify examples
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper committed Oct 8, 2024
1 parent b50f197 commit f3355f7
Show file tree
Hide file tree
Showing 16 changed files with 160 additions and 84 deletions.
10 changes: 5 additions & 5 deletions bench/enif_merge_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,21 @@ defmodule ENIFMergeSort do
func.return
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
call ENIFTimSort.copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
15 changes: 5 additions & 10 deletions bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ defmodule ENIFQuickSort do
Pointer.store(val_a, b)
val_tmp = Pointer.load(Term.t(), tmp)
Pointer.store(val_tmp, a)
func.return()
end

defm partition(arr :: Pointer.t(), low :: i32(), high :: i32()) :: i32() do
Expand Down Expand Up @@ -41,8 +40,6 @@ defmodule ENIFQuickSort do
do_sort(arr, low, pi - 1)
do_sort(arr, pi + 1, high)
end

func.return()
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
Expand All @@ -65,25 +62,23 @@ defmodule ENIFQuickSort do
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end

func.return()
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
16 changes: 5 additions & 11 deletions bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ defmodule ENIFTimSort do
j = Pointer.load(i32(), j_ptr)
Pointer.store(temp, Pointer.element_ptr(Term.t(), arr, j + 1))
end

func.return()
end

defm tim_sort(arr :: Pointer.t(), n :: i32()) do
Expand Down Expand Up @@ -73,8 +71,6 @@ defmodule ENIFTimSort do

Pointer.store(size * 2, size_ptr)
end

func.return()
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
Expand All @@ -97,24 +93,22 @@ defmodule ENIFTimSort do
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end

func.return()
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
tim_sort(arr, len)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
12 changes: 6 additions & 6 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
arr = Enum.to_list(1..10000) |> Enum.shuffle()
ENIFQuickSort.sort(arr, :arg_err)
ENIFMergeSort.sort(arr, :arg_err)
ENIFTimSort.sort(arr, :arg_err)
ENIFQuickSort.sort(arr)
ENIFMergeSort.sort(arr)
ENIFTimSort.sort(arr)

Benchee.run(
%{
"Enum.sort" => &Enum.sort/1,
"enif_quick_sort" => &ENIFQuickSort.sort(&1, :arg_err),
"enif_merge_sort" => &ENIFMergeSort.sort(&1, :arg_err),
"enif_tim_sort" => &ENIFTimSort.sort(&1, :arg_err)
"enif_quick_sort" => &ENIFQuickSort.sort(&1),
"enif_merge_sort" => &ENIFMergeSort.sort(&1),
"enif_tim_sort" => &ENIFTimSort.sort(&1)
},
inputs: %{
"array size 10" => 10,
Expand Down
26 changes: 11 additions & 15 deletions bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ defmodule AddTwoIntVec do
Pointer.load(i32(), v_ptr)
|> vector.insertelement(acc, i)
end)
|> func.return()
end

defm add(env, a, b, error) :: Term.t() do
Expand All @@ -27,20 +26,17 @@ defmodule AddTwoIntVec do
v = arith.addi(v1, v2)
start = const 0 :: i32()

ret =
enif_make_list8(
env,
enif_make_int(env, vector.extractelement(v, start)),
enif_make_int(env, vector.extractelement(v, start + 1)),
enif_make_int(env, vector.extractelement(v, start + 2)),
enif_make_int(env, vector.extractelement(v, start + 3)),
enif_make_int(env, vector.extractelement(v, start + 4)),
enif_make_int(env, vector.extractelement(v, start + 5)),
enif_make_int(env, vector.extractelement(v, start + 6)),
enif_make_int(env, vector.extractelement(v, start + 7))
)

func.return(ret)
enif_make_list8(
env,
enif_make_int(env, vector.extractelement(v, start)),
enif_make_int(env, vector.extractelement(v, start + 1)),
enif_make_int(env, vector.extractelement(v, start + 2)),
enif_make_int(env, vector.extractelement(v, start + 3)),
enif_make_int(env, vector.extractelement(v, start + 4)),
enif_make_int(env, vector.extractelement(v, start + 5)),
enif_make_int(env, vector.extractelement(v, start + 6)),
enif_make_int(env, vector.extractelement(v, start + 7))
)
end

defm dummy_load_no_make(env, a, b, error) :: Term.t() do
Expand Down
23 changes: 23 additions & 0 deletions lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,25 @@ defmodule Charms.Defm do
:ok
end

# if it is single block with no terminator, add a return
defp autocomplete_func_return(func) do
with [r] <- Beaver.Walker.regions(func) |> Enum.to_list(),
[b] <- Beaver.Walker.blocks(r) |> Enum.to_list(),
last_op = %MLIR.Operation{} <-
Beaver.Walker.operations(b) |> Enum.to_list() |> List.last(),
false <- MLIR.Operation.name(last_op) == "func.return" do
mlir ctx: MLIR.CAPI.mlirOperationGetContext(last_op), block: b do
results = Beaver.Walker.results(last_op) |> Enum.to_list()
Func.return(results) >>> []
end
else
_ ->
nil
end

:ok
end

defp referenced_modules(module) do
Beaver.Walker.postwalk(module, MapSet.new(), fn
%MLIR.Operation{} = op, acc ->
Expand Down Expand Up @@ -166,6 +185,10 @@ defmodule Charms.Defm do

m
|> Charms.Debug.print_ir_pass()
|> MLIR.Pass.Composer.nested(
"func.func",
{"autocomplete_func_return", "func.func", &autocomplete_func_return/1}
)
|> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> canonicalize
Expand Down
37 changes: 26 additions & 11 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,12 @@ defmodule Charms.Defm.Expander do
defp expand({fun, _meta, [left, right]}, state, env) when fun in @intrinsics do
{left, state, env} = expand(left, state, env)
{right, state, env} = expand(right, state, env)
loc = MLIR.Location.from_env(env)

{Charms.Prelude.handle_intrinsic(fun, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
), state, env}
end

Expand Down Expand Up @@ -463,6 +465,7 @@ defmodule Charms.Defm.Expander do
arity = length(args)
mfa = {module, fun, arity}
state = update_in(state.remotes, &[mfa | &1])
loc = MLIR.Location.from_env(env)

if is_atom(module) do
try do
Expand All @@ -480,8 +483,11 @@ defmodule Charms.Defm.Expander do
function_exported?(module, :__intrinsics__, 0) and fun in module.__intrinsics__() ->
{args, state, env} = expand(args, state, env)

{module.handle_intrinsic(fun, args, ctx: state.mlir.ctx, block: state.mlir.blk),
state, env}
{module.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}

module == MLIR.Attribute ->
{args, state, env} = expand(args, state, env)
Expand All @@ -499,8 +505,9 @@ defmodule Charms.Defm.Expander do
attr = unquote(attr)
term_ptr = Pointer.allocate(Term.t())
size = String.length(attr)
size = value index.casts(size) :: i64()
buffer_ptr = Pointer.allocate(i8(), size)
buffer = ptr_to_memref(buffer_ptr)
buffer = ptr_to_memref(buffer_ptr, size)
memref.copy(attr, buffer)
zero = const 0 :: i32()
enif_binary_to_term(unquote(env_var), buffer_ptr, size, term_ptr, zero)
Expand Down Expand Up @@ -659,21 +666,23 @@ defmodule Charms.Defm.Expander do

## Fallback

@const_prefix "chc"
defp expand(ast, state, env) when is_binary(ast) do
s_table = state.mlir.mod |> MLIR.Operation.from_module() |> MLIR.CAPI.mlirSymbolTableCreate()
sym_name = "__const__" <> :crypto.hash(:sha256, ast)
sym_name = @const_prefix <> :crypto.hash(:sha256, ast)
found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name))
loc = MLIR.Location.from_env(env)

mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do
if MLIR.is_null(found) do
MemRef.global(ast, sym_name: Attribute.string(sym_name)) >>> :infer
MemRef.global(ast, sym_name: Attribute.string(sym_name), loc: loc) >>> :infer
else
found
end
|> then(
&mlir block: state.mlir.blk do
name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name]))
MemRef.get_global(name: name) >>> Attribute.unwrap(&1[:type])
MemRef.get_global(name: name, loc: loc) >>> Attribute.unwrap(&1[:type])
end
)
end
Expand Down Expand Up @@ -851,6 +860,7 @@ defmodule Charms.Defm.Expander do
true_body = Keyword.fetch!(clauses, :do)
false_body = clauses[:else]
{condition, state, env} = expand(condition, state, env)
loc = MLIR.Location.from_env(env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
Expand All @@ -861,7 +871,7 @@ defmodule Charms.Defm.Expander do
end

# TODO: doc about an expression which is a value and an operation
SCF.if [condition] do
SCF.if [condition, loc: loc] do
region do
MLIR.CAPI.mlirRegionAppendOwnedBlock(Beaver.Env.region(), b)
end
Expand Down Expand Up @@ -1056,11 +1066,13 @@ defmodule Charms.Defm.Expander do
## Helpers

defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @intrinsics do
loc = MLIR.Location.from_env(env)
{args, state, env} = expand(args, state, env)

{Charms.Prelude.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
), state, env}
end

Expand All @@ -1086,14 +1098,16 @@ defmodule Charms.Defm.Expander do
state = update_in(state.locals, &[{fun, length(args)} | &1])
{args, state, env} = expand_list(args, state, env)
Code.ensure_loaded!(MLIR.Type)
loc = MLIR.Location.from_env(env)

if function_exported?(MLIR.Type, fun, 1) do
{apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env}
else
case i =
Charms.Prelude.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
) do
:not_handled ->
create_call(env.module, fun, args, [], state, env)
Expand Down Expand Up @@ -1177,8 +1191,9 @@ defmodule Charms.Defm.Expander do
end
end

@var_prefix "chv"
defp uniq_mlir_var() do
Macro.var(:"chv#{System.unique_integer([:positive])}", nil)
Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil)
end

defp uniq_mlir_var(state, val) do
Expand Down
9 changes: 8 additions & 1 deletion lib/charms/defm/pass/create_absent_func.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ defmodule Charms.Defm.Pass.CreateAbsentFunc do
name_str <- MLIR.StringRef.to_string(name),
false <- MapSet.member?(created, name_str) do
mlir ctx: ctx, block: block do
{arg_types, ret_types} =
if s = Beaver.ENIF.signature(ctx, String.to_atom(name_str)) do
s
else
{arg_types, ret_types}
end

Func.func _(
sym_name: "\"#{name_str}\"",
sym_name: MLIR.Attribute.string(name_str),
sym_visibility: MLIR.Attribute.string("private"),
function_type: Type.function(arg_types, ret_types)
) do
Expand Down
2 changes: 1 addition & 1 deletion lib/charms/intrinsic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Charms.Intrinsic do
Behaviour to define intrinsic functions.
"""
alias Beaver
@type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t()}
@type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t() | {:loc, MLIR.Location.t()}}
@type opts :: [opt | {atom(), term()}]
@type ir_return :: MLIR.Value.t() | MLIR.Operation.t()
@type intrinsic_return :: ir_return() | (any() -> ir_return())
Expand Down
Loading

0 comments on commit f3355f7

Please sign in to comment.