Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify examples #37

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bench/enif_merge_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,21 @@ defmodule ENIFMergeSort do
func.return
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
call ENIFTimSort.copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
15 changes: 5 additions & 10 deletions bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ defmodule ENIFQuickSort do
Pointer.store(val_a, b)
val_tmp = Pointer.load(Term.t(), tmp)
Pointer.store(val_tmp, a)
func.return()
end

defm partition(arr :: Pointer.t(), low :: i32(), high :: i32()) :: i32() do
Expand Down Expand Up @@ -41,8 +40,6 @@ defmodule ENIFQuickSort do
do_sort(arr, low, pi - 1)
do_sort(arr, pi + 1, high)
end

func.return()
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
Expand All @@ -65,25 +62,23 @@ defmodule ENIFQuickSort do
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end

func.return()
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
zero = const 0 :: i32()
do_sort(arr, zero, len - 1)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
16 changes: 5 additions & 11 deletions bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ defmodule ENIFTimSort do
j = Pointer.load(i32(), j_ptr)
Pointer.store(temp, Pointer.element_ptr(Term.t(), arr, j + 1))
end

func.return()
end

defm tim_sort(arr :: Pointer.t(), n :: i32()) do
Expand Down Expand Up @@ -73,8 +71,6 @@ defmodule ENIFTimSort do

Pointer.store(size * 2, size_ptr)
end

func.return()
end

defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do
Expand All @@ -97,24 +93,22 @@ defmodule ENIFTimSort do
Pointer.store(head_val, ith_term_ptr)
Pointer.store(i + 1, i_ptr)
end

func.return()
end

defm sort(env, list, err) :: Term.t() do
@err %ArgumentError{message: "list expected"}
defm sort(env, list) :: Term.t() do
len_ptr = Pointer.allocate(i32())

cond_br(enif_get_list_length(env, list, len_ptr) != 0) do
if enif_get_list_length(env, list, len_ptr) != 0 do
movable_list_ptr = Pointer.allocate(Term.t())
Pointer.store(list, movable_list_ptr)
len = Pointer.load(i32(), len_ptr)
arr = Pointer.allocate(Term.t(), len)
copy_terms(env, movable_list_ptr, arr)
tim_sort(arr, len)
ret = enif_make_list_from_array(env, arr, len)
func.return(ret)
enif_make_list_from_array(env, arr, len)
else
func.return(err)
enif_raise_exception(env, @err)
end
end
end
12 changes: 6 additions & 6 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
arr = Enum.to_list(1..10000) |> Enum.shuffle()
ENIFQuickSort.sort(arr, :arg_err)
ENIFMergeSort.sort(arr, :arg_err)
ENIFTimSort.sort(arr, :arg_err)
ENIFQuickSort.sort(arr)
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
ENIFMergeSort.sort(arr)
ENIFTimSort.sort(arr)

Benchee.run(
%{
"Enum.sort" => &Enum.sort/1,
"enif_quick_sort" => &ENIFQuickSort.sort(&1, :arg_err),
"enif_merge_sort" => &ENIFMergeSort.sort(&1, :arg_err),
"enif_tim_sort" => &ENIFTimSort.sort(&1, :arg_err)
"enif_quick_sort" => &ENIFQuickSort.sort(&1),
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
"enif_merge_sort" => &ENIFMergeSort.sort(&1),
"enif_tim_sort" => &ENIFTimSort.sort(&1)
},
inputs: %{
"array size 10" => 10,
Expand Down
26 changes: 11 additions & 15 deletions bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ defmodule AddTwoIntVec do
Pointer.load(i32(), v_ptr)
|> vector.insertelement(acc, i)
end)
|> func.return()
end

defm add(env, a, b, error) :: Term.t() do
Expand All @@ -27,20 +26,17 @@ defmodule AddTwoIntVec do
v = arith.addi(v1, v2)
start = const 0 :: i32()

ret =
enif_make_list8(
env,
enif_make_int(env, vector.extractelement(v, start)),
enif_make_int(env, vector.extractelement(v, start + 1)),
enif_make_int(env, vector.extractelement(v, start + 2)),
enif_make_int(env, vector.extractelement(v, start + 3)),
enif_make_int(env, vector.extractelement(v, start + 4)),
enif_make_int(env, vector.extractelement(v, start + 5)),
enif_make_int(env, vector.extractelement(v, start + 6)),
enif_make_int(env, vector.extractelement(v, start + 7))
)

func.return(ret)
enif_make_list8(
env,
enif_make_int(env, vector.extractelement(v, start)),
enif_make_int(env, vector.extractelement(v, start + 1)),
enif_make_int(env, vector.extractelement(v, start + 2)),
enif_make_int(env, vector.extractelement(v, start + 3)),
enif_make_int(env, vector.extractelement(v, start + 4)),
enif_make_int(env, vector.extractelement(v, start + 5)),
enif_make_int(env, vector.extractelement(v, start + 6)),
enif_make_int(env, vector.extractelement(v, start + 7))
)
end

defm dummy_load_no_make(env, a, b, error) :: Term.t() do
Expand Down
23 changes: 23 additions & 0 deletions lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,25 @@ defmodule Charms.Defm do
:ok
end

# if it is single block with no terminator, add a return
defp append_missing_return(func) do
with [r] <- Beaver.Walker.regions(func) |> Enum.to_list(),
[b] <- Beaver.Walker.blocks(r) |> Enum.to_list(),
last_op = %MLIR.Operation{} <-
Beaver.Walker.operations(b) |> Enum.to_list() |> List.last(),
false <- MLIR.Operation.name(last_op) == "func.return" do
mlir ctx: MLIR.CAPI.mlirOperationGetContext(last_op), block: b do
results = Beaver.Walker.results(last_op) |> Enum.to_list()
Func.return(results) >>> []
end
else
_ ->
nil
end

:ok
end

defp referenced_modules(module) do
Beaver.Walker.postwalk(module, MapSet.new(), fn
%MLIR.Operation{} = op, acc ->
Expand Down Expand Up @@ -166,6 +185,10 @@ defmodule Charms.Defm do

m
|> Charms.Debug.print_ir_pass()
|> MLIR.Pass.Composer.nested(
"func.func",
{"append_missing_return", "func.func", &append_missing_return/1}
)
|> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> canonicalize
Expand Down
37 changes: 26 additions & 11 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,12 @@ defmodule Charms.Defm.Expander do
defp expand({fun, _meta, [left, right]}, state, env) when fun in @intrinsics do
{left, state, env} = expand(left, state, env)
{right, state, env} = expand(right, state, env)
loc = MLIR.Location.from_env(env)

{Charms.Prelude.handle_intrinsic(fun, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
), state, env}
end

Expand Down Expand Up @@ -463,6 +465,7 @@ defmodule Charms.Defm.Expander do
arity = length(args)
mfa = {module, fun, arity}
state = update_in(state.remotes, &[mfa | &1])
loc = MLIR.Location.from_env(env)

if is_atom(module) do
try do
Expand All @@ -480,8 +483,11 @@ defmodule Charms.Defm.Expander do
function_exported?(module, :__intrinsics__, 0) and fun in module.__intrinsics__() ->
{args, state, env} = expand(args, state, env)

{module.handle_intrinsic(fun, args, ctx: state.mlir.ctx, block: state.mlir.blk),
state, env}
{module.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}

module == MLIR.Attribute ->
{args, state, env} = expand(args, state, env)
Expand All @@ -499,8 +505,9 @@ defmodule Charms.Defm.Expander do
attr = unquote(attr)
term_ptr = Pointer.allocate(Term.t())
size = String.length(attr)
size = value index.casts(size) :: i64()
buffer_ptr = Pointer.allocate(i8(), size)
buffer = ptr_to_memref(buffer_ptr)
buffer = ptr_to_memref(buffer_ptr, size)
memref.copy(attr, buffer)
zero = const 0 :: i32()
enif_binary_to_term(unquote(env_var), buffer_ptr, size, term_ptr, zero)
Expand Down Expand Up @@ -659,21 +666,23 @@ defmodule Charms.Defm.Expander do

## Fallback

@const_prefix "chc"
defp expand(ast, state, env) when is_binary(ast) do
s_table = state.mlir.mod |> MLIR.Operation.from_module() |> MLIR.CAPI.mlirSymbolTableCreate()
sym_name = "__const__" <> :crypto.hash(:sha256, ast)
sym_name = @const_prefix <> :crypto.hash(:sha256, ast)
found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name))
loc = MLIR.Location.from_env(env)

mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do
if MLIR.is_null(found) do
MemRef.global(ast, sym_name: Attribute.string(sym_name)) >>> :infer
MemRef.global(ast, sym_name: Attribute.string(sym_name), loc: loc) >>> :infer
else
found
end
|> then(
&mlir block: state.mlir.blk do
name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name]))
MemRef.get_global(name: name) >>> Attribute.unwrap(&1[:type])
MemRef.get_global(name: name, loc: loc) >>> Attribute.unwrap(&1[:type])
end
)
end
Expand Down Expand Up @@ -851,6 +860,7 @@ defmodule Charms.Defm.Expander do
true_body = Keyword.fetch!(clauses, :do)
false_body = clauses[:else]
{condition, state, env} = expand(condition, state, env)
loc = MLIR.Location.from_env(env)

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
Expand All @@ -861,7 +871,7 @@ defmodule Charms.Defm.Expander do
end

# TODO: doc about an expression which is a value and an operation
SCF.if [condition] do
SCF.if [condition, loc: loc] do
region do
MLIR.CAPI.mlirRegionAppendOwnedBlock(Beaver.Env.region(), b)
end
Expand Down Expand Up @@ -1056,11 +1066,13 @@ defmodule Charms.Defm.Expander do
## Helpers

defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @intrinsics do
loc = MLIR.Location.from_env(env)
{args, state, env} = expand(args, state, env)

{Charms.Prelude.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
), state, env}
end

Expand All @@ -1086,14 +1098,16 @@ defmodule Charms.Defm.Expander do
state = update_in(state.locals, &[{fun, length(args)} | &1])
{args, state, env} = expand_list(args, state, env)
Code.ensure_loaded!(MLIR.Type)
loc = MLIR.Location.from_env(env)

if function_exported?(MLIR.Type, fun, 1) do
{apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env}
else
case i =
Charms.Prelude.handle_intrinsic(fun, args,
ctx: state.mlir.ctx,
block: state.mlir.blk
block: state.mlir.blk,
loc: loc
) do
:not_handled ->
create_call(env.module, fun, args, [], state, env)
Expand Down Expand Up @@ -1177,8 +1191,9 @@ defmodule Charms.Defm.Expander do
end
end

@var_prefix "chv"
defp uniq_mlir_var() do
Macro.var(:"chv#{System.unique_integer([:positive])}", nil)
Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil)
end

defp uniq_mlir_var(state, val) do
Expand Down
9 changes: 8 additions & 1 deletion lib/charms/defm/pass/create_absent_func.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ defmodule Charms.Defm.Pass.CreateAbsentFunc do
name_str <- MLIR.StringRef.to_string(name),
false <- MapSet.member?(created, name_str) do
mlir ctx: ctx, block: block do
{arg_types, ret_types} =
if s = Beaver.ENIF.signature(ctx, String.to_atom(name_str)) do
s
else
{arg_types, ret_types}
end

Func.func _(
sym_name: "\"#{name_str}\"",
sym_name: MLIR.Attribute.string(name_str),
sym_visibility: MLIR.Attribute.string("private"),
function_type: Type.function(arg_types, ret_types)
) do
Expand Down
2 changes: 1 addition & 1 deletion lib/charms/intrinsic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defmodule Charms.Intrinsic do
Behaviour to define intrinsic functions.
"""
alias Beaver
@type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t()}
@type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t() | {:loc, MLIR.Location.t()}}
@type opts :: [opt | {atom(), term()}]
@type ir_return :: MLIR.Value.t() | MLIR.Operation.t()
@type intrinsic_return :: ir_return() | (any() -> ir_return())
Expand Down
Loading
Loading