From 15c3069c8f089dc486d5acaadc1321b262f4a569 Mon Sep 17 00:00:00 2001 From: tsai Date: Tue, 8 Oct 2024 18:33:02 +0800 Subject: [PATCH] simplify examples --- bench/enif_merge_sort.ex | 10 +++--- bench/enif_quick_sort.ex | 15 +++------ bench/enif_tim_sort.ex | 16 +++------ bench/sort_benchmark.exs | 12 +++---- bench/vec_add_int_list.ex | 26 +++++++-------- lib/charms/defm.ex | 23 +++++++++++++ lib/charms/defm/expander.ex | 37 ++++++++++++++------- lib/charms/defm/pass/create_absent_func.ex | 9 ++++- lib/charms/intrinsic.ex | 2 +- lib/charms/jit.ex | 16 +++------ lib/charms/prelude.ex | 2 +- test/defm_test.exs | 11 ++++--- test/mod_merge_test.exs | 14 ++++++++ test/string_test.exs | 10 +++--- test/support/merge_mod.ex | 38 ++++++++++++++++++++++ test/vec_add_test.exs | 3 +- 16 files changed, 160 insertions(+), 84 deletions(-) create mode 100644 test/mod_merge_test.exs create mode 100644 test/support/merge_mod.ex diff --git a/bench/enif_merge_sort.ex b/bench/enif_merge_sort.ex index c830c24..028a656 100644 --- a/bench/enif_merge_sort.ex +++ b/bench/enif_merge_sort.ex @@ -100,10 +100,11 @@ defmodule ENIFMergeSort do func.return end - defm sort(env, list, err) :: Term.t() do + @err %ArgumentError{message: "list expected"} + defm sort(env, list) :: Term.t() do len_ptr = Pointer.allocate(i32()) - cond_br(enif_get_list_length(env, list, len_ptr) != 0) do + if enif_get_list_length(env, list, len_ptr) != 0 do movable_list_ptr = Pointer.allocate(Term.t()) Pointer.store(list, movable_list_ptr) len = Pointer.load(i32(), len_ptr) @@ -111,10 +112,9 @@ defmodule ENIFMergeSort do call ENIFTimSort.copy_terms(env, movable_list_ptr, arr) zero = const 0 :: i32() do_sort(arr, zero, len - 1) - ret = enif_make_list_from_array(env, arr, len) - func.return(ret) + enif_make_list_from_array(env, arr, len) else - func.return(err) + enif_raise_exception(env, @err) end end end diff --git a/bench/enif_quick_sort.ex b/bench/enif_quick_sort.ex index 088edda..d71be2f 100644 --- a/bench/enif_quick_sort.ex +++ b/bench/enif_quick_sort.ex @@ -11,7 +11,6 @@ defmodule ENIFQuickSort do Pointer.store(val_a, b) val_tmp = Pointer.load(Term.t(), tmp) Pointer.store(val_tmp, a) - func.return() end defm partition(arr :: Pointer.t(), low :: i32(), high :: i32()) :: i32() do @@ -41,8 +40,6 @@ defmodule ENIFQuickSort do do_sort(arr, low, pi - 1) do_sort(arr, pi + 1, high) end - - func.return() end defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do @@ -65,14 +62,13 @@ defmodule ENIFQuickSort do Pointer.store(head_val, ith_term_ptr) Pointer.store(i + 1, i_ptr) end - - func.return() end - defm sort(env, list, err) :: Term.t() do + @err %ArgumentError{message: "list expected"} + defm sort(env, list) :: Term.t() do len_ptr = Pointer.allocate(i32()) - cond_br(enif_get_list_length(env, list, len_ptr) != 0) do + if enif_get_list_length(env, list, len_ptr) != 0 do movable_list_ptr = Pointer.allocate(Term.t()) Pointer.store(list, movable_list_ptr) len = Pointer.load(i32(), len_ptr) @@ -80,10 +76,9 @@ defmodule ENIFQuickSort do copy_terms(env, movable_list_ptr, arr) zero = const 0 :: i32() do_sort(arr, zero, len - 1) - ret = enif_make_list_from_array(env, arr, len) - func.return(ret) + enif_make_list_from_array(env, arr, len) else - func.return(err) + enif_raise_exception(env, @err) end end end diff --git a/bench/enif_tim_sort.ex b/bench/enif_tim_sort.ex index 88f54f7..3510133 100644 --- a/bench/enif_tim_sort.ex +++ b/bench/enif_tim_sort.ex @@ -32,8 +32,6 @@ defmodule ENIFTimSort do j = Pointer.load(i32(), j_ptr) Pointer.store(temp, Pointer.element_ptr(Term.t(), arr, j + 1)) end - - func.return() end defm tim_sort(arr :: Pointer.t(), n :: i32()) do @@ -73,8 +71,6 @@ defmodule ENIFTimSort do Pointer.store(size * 2, size_ptr) end - - func.return() end defm copy_terms(env :: Env.t(), movable_list_ptr :: Pointer.t(), arr :: Pointer.t()) do @@ -97,24 +93,22 @@ defmodule ENIFTimSort do Pointer.store(head_val, ith_term_ptr) Pointer.store(i + 1, i_ptr) end - - func.return() end - defm sort(env, list, err) :: Term.t() do + @err %ArgumentError{message: "list expected"} + defm sort(env, list) :: Term.t() do len_ptr = Pointer.allocate(i32()) - cond_br(enif_get_list_length(env, list, len_ptr) != 0) do + if enif_get_list_length(env, list, len_ptr) != 0 do movable_list_ptr = Pointer.allocate(Term.t()) Pointer.store(list, movable_list_ptr) len = Pointer.load(i32(), len_ptr) arr = Pointer.allocate(Term.t(), len) copy_terms(env, movable_list_ptr, arr) tim_sort(arr, len) - ret = enif_make_list_from_array(env, arr, len) - func.return(ret) + enif_make_list_from_array(env, arr, len) else - func.return(err) + enif_raise_exception(env, @err) end end end diff --git a/bench/sort_benchmark.exs b/bench/sort_benchmark.exs index 5e207dd..08dd5eb 100644 --- a/bench/sort_benchmark.exs +++ b/bench/sort_benchmark.exs @@ -1,14 +1,14 @@ arr = Enum.to_list(1..10000) |> Enum.shuffle() -ENIFQuickSort.sort(arr, :arg_err) -ENIFMergeSort.sort(arr, :arg_err) -ENIFTimSort.sort(arr, :arg_err) +ENIFQuickSort.sort(arr) +ENIFMergeSort.sort(arr) +ENIFTimSort.sort(arr) Benchee.run( %{ "Enum.sort" => &Enum.sort/1, - "enif_quick_sort" => &ENIFQuickSort.sort(&1, :arg_err), - "enif_merge_sort" => &ENIFMergeSort.sort(&1, :arg_err), - "enif_tim_sort" => &ENIFTimSort.sort(&1, :arg_err) + "enif_quick_sort" => &ENIFQuickSort.sort(&1), + "enif_merge_sort" => &ENIFMergeSort.sort(&1), + "enif_tim_sort" => &ENIFTimSort.sort(&1) }, inputs: %{ "array size 10" => 10, diff --git a/bench/vec_add_int_list.ex b/bench/vec_add_int_list.ex index 7f62907..d4b875c 100644 --- a/bench/vec_add_int_list.ex +++ b/bench/vec_add_int_list.ex @@ -18,7 +18,6 @@ defmodule AddTwoIntVec do Pointer.load(i32(), v_ptr) |> vector.insertelement(acc, i) end) - |> func.return() end defm add(env, a, b, error) :: Term.t() do @@ -27,20 +26,17 @@ defmodule AddTwoIntVec do v = arith.addi(v1, v2) start = const 0 :: i32() - ret = - enif_make_list8( - env, - enif_make_int(env, vector.extractelement(v, start)), - enif_make_int(env, vector.extractelement(v, start + 1)), - enif_make_int(env, vector.extractelement(v, start + 2)), - enif_make_int(env, vector.extractelement(v, start + 3)), - enif_make_int(env, vector.extractelement(v, start + 4)), - enif_make_int(env, vector.extractelement(v, start + 5)), - enif_make_int(env, vector.extractelement(v, start + 6)), - enif_make_int(env, vector.extractelement(v, start + 7)) - ) - - func.return(ret) + enif_make_list8( + env, + enif_make_int(env, vector.extractelement(v, start)), + enif_make_int(env, vector.extractelement(v, start + 1)), + enif_make_int(env, vector.extractelement(v, start + 2)), + enif_make_int(env, vector.extractelement(v, start + 3)), + enif_make_int(env, vector.extractelement(v, start + 4)), + enif_make_int(env, vector.extractelement(v, start + 5)), + enif_make_int(env, vector.extractelement(v, start + 6)), + enif_make_int(env, vector.extractelement(v, start + 7)) + ) end defm dummy_load_no_make(env, a, b, error) :: Term.t() do diff --git a/lib/charms/defm.ex b/lib/charms/defm.ex index 5a04214..7644b9c 100644 --- a/lib/charms/defm.ex +++ b/lib/charms/defm.ex @@ -113,6 +113,25 @@ defmodule Charms.Defm do :ok end + # if it is single block with no terminator, add a return + defp append_missing_return(func) do + with [r] <- Beaver.Walker.regions(func) |> Enum.to_list(), + [b] <- Beaver.Walker.blocks(r) |> Enum.to_list(), + last_op = %MLIR.Operation{} <- + Beaver.Walker.operations(b) |> Enum.to_list() |> List.last(), + false <- MLIR.Operation.name(last_op) == "func.return" do + mlir ctx: MLIR.CAPI.mlirOperationGetContext(last_op), block: b do + results = Beaver.Walker.results(last_op) |> Enum.to_list() + Func.return(results) >>> [] + end + else + _ -> + nil + end + + :ok + end + defp referenced_modules(module) do Beaver.Walker.postwalk(module, MapSet.new(), fn %MLIR.Operation{} = op, acc -> @@ -166,6 +185,10 @@ defmodule Charms.Defm do m |> Charms.Debug.print_ir_pass() + |> MLIR.Pass.Composer.nested( + "func.func", + {"append_missing_return", "func.func", &append_missing_return/1} + ) |> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc) |> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1}) |> canonicalize diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 749ea1a..3496d9e 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -402,10 +402,12 @@ defmodule Charms.Defm.Expander do defp expand({fun, _meta, [left, right]}, state, env) when fun in @intrinsics do {left, state, env} = expand(left, state, env) {right, state, env} = expand(right, state, env) + loc = MLIR.Location.from_env(env) {Charms.Prelude.handle_intrinsic(fun, [left, right], ctx: state.mlir.ctx, - block: state.mlir.blk + block: state.mlir.blk, + loc: loc ), state, env} end @@ -463,6 +465,7 @@ defmodule Charms.Defm.Expander do arity = length(args) mfa = {module, fun, arity} state = update_in(state.remotes, &[mfa | &1]) + loc = MLIR.Location.from_env(env) if is_atom(module) do try do @@ -480,8 +483,11 @@ defmodule Charms.Defm.Expander do function_exported?(module, :__intrinsics__, 0) and fun in module.__intrinsics__() -> {args, state, env} = expand(args, state, env) - {module.handle_intrinsic(fun, args, ctx: state.mlir.ctx, block: state.mlir.blk), - state, env} + {module.handle_intrinsic(fun, args, + ctx: state.mlir.ctx, + block: state.mlir.blk, + loc: loc + ), state, env} module == MLIR.Attribute -> {args, state, env} = expand(args, state, env) @@ -499,8 +505,9 @@ defmodule Charms.Defm.Expander do attr = unquote(attr) term_ptr = Pointer.allocate(Term.t()) size = String.length(attr) + size = value index.casts(size) :: i64() buffer_ptr = Pointer.allocate(i8(), size) - buffer = ptr_to_memref(buffer_ptr) + buffer = ptr_to_memref(buffer_ptr, size) memref.copy(attr, buffer) zero = const 0 :: i32() enif_binary_to_term(unquote(env_var), buffer_ptr, size, term_ptr, zero) @@ -659,21 +666,23 @@ defmodule Charms.Defm.Expander do ## Fallback + @const_prefix "chc" defp expand(ast, state, env) when is_binary(ast) do s_table = state.mlir.mod |> MLIR.Operation.from_module() |> MLIR.CAPI.mlirSymbolTableCreate() - sym_name = "__const__" <> :crypto.hash(:sha256, ast) + sym_name = @const_prefix <> :crypto.hash(:sha256, ast) found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name)) + loc = MLIR.Location.from_env(env) mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do if MLIR.is_null(found) do - MemRef.global(ast, sym_name: Attribute.string(sym_name)) >>> :infer + MemRef.global(ast, sym_name: Attribute.string(sym_name), loc: loc) >>> :infer else found end |> then( &mlir block: state.mlir.blk do name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name])) - MemRef.get_global(name: name) >>> Attribute.unwrap(&1[:type]) + MemRef.get_global(name: name, loc: loc) >>> Attribute.unwrap(&1[:type]) end ) end @@ -851,6 +860,7 @@ defmodule Charms.Defm.Expander do true_body = Keyword.fetch!(clauses, :do) false_body = clauses[:else] {condition, state, env} = expand(condition, state, env) + loc = MLIR.Location.from_env(env) v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do @@ -861,7 +871,7 @@ defmodule Charms.Defm.Expander do end # TODO: doc about an expression which is a value and an operation - SCF.if [condition] do + SCF.if [condition, loc: loc] do region do MLIR.CAPI.mlirRegionAppendOwnedBlock(Beaver.Env.region(), b) end @@ -1056,11 +1066,13 @@ defmodule Charms.Defm.Expander do ## Helpers defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @intrinsics do + loc = MLIR.Location.from_env(env) {args, state, env} = expand(args, state, env) {Charms.Prelude.handle_intrinsic(fun, args, ctx: state.mlir.ctx, - block: state.mlir.blk + block: state.mlir.blk, + loc: loc ), state, env} end @@ -1086,6 +1098,7 @@ defmodule Charms.Defm.Expander do state = update_in(state.locals, &[{fun, length(args)} | &1]) {args, state, env} = expand_list(args, state, env) Code.ensure_loaded!(MLIR.Type) + loc = MLIR.Location.from_env(env) if function_exported?(MLIR.Type, fun, 1) do {apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env} @@ -1093,7 +1106,8 @@ defmodule Charms.Defm.Expander do case i = Charms.Prelude.handle_intrinsic(fun, args, ctx: state.mlir.ctx, - block: state.mlir.blk + block: state.mlir.blk, + loc: loc ) do :not_handled -> create_call(env.module, fun, args, [], state, env) @@ -1177,8 +1191,9 @@ defmodule Charms.Defm.Expander do end end + @var_prefix "chv" defp uniq_mlir_var() do - Macro.var(:"chv#{System.unique_integer([:positive])}", nil) + Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil) end defp uniq_mlir_var(state, val) do diff --git a/lib/charms/defm/pass/create_absent_func.ex b/lib/charms/defm/pass/create_absent_func.ex index 6fd7e96..0f36a92 100644 --- a/lib/charms/defm/pass/create_absent_func.ex +++ b/lib/charms/defm/pass/create_absent_func.ex @@ -42,8 +42,15 @@ defmodule Charms.Defm.Pass.CreateAbsentFunc do name_str <- MLIR.StringRef.to_string(name), false <- MapSet.member?(created, name_str) do mlir ctx: ctx, block: block do + {arg_types, ret_types} = + if s = Beaver.ENIF.signature(ctx, String.to_atom(name_str)) do + s + else + {arg_types, ret_types} + end + Func.func _( - sym_name: "\"#{name_str}\"", + sym_name: MLIR.Attribute.string(name_str), sym_visibility: MLIR.Attribute.string("private"), function_type: Type.function(arg_types, ret_types) ) do diff --git a/lib/charms/intrinsic.ex b/lib/charms/intrinsic.ex index 5fc32f6..2c35e58 100644 --- a/lib/charms/intrinsic.ex +++ b/lib/charms/intrinsic.ex @@ -3,7 +3,7 @@ defmodule Charms.Intrinsic do Behaviour to define intrinsic functions. """ alias Beaver - @type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t()} + @type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t() | {:loc, MLIR.Location.t()}} @type opts :: [opt | {atom(), term()}] @type ir_return :: MLIR.Value.t() | MLIR.Operation.t() @type intrinsic_return :: ir_return() | (any() -> ir_return()) diff --git a/lib/charms/jit.ex b/lib/charms/jit.ex index 7b51d48..034a9c4 100644 --- a/lib/charms/jit.ex +++ b/lib/charms/jit.ex @@ -25,11 +25,11 @@ defmodule Charms.JIT do |> tap(&beaver_raw_jit_register_enif(&1.ref)) end - defp clone_func_impl(to, from) do + defp clone_ops(to, from) do ops = MLIR.Module.body(from) |> Beaver.Walker.operations() s_table = to |> MLIR.Operation.from_module() |> mlirSymbolTableCreate() - for op <- ops, MLIR.Operation.name(op) == "func.func" do + for op <- ops, MLIR.Operation.name(op) in ~w{func.func memref.global} do sym = mlirOperationGetAttributeByName(op, mlirSymbolTableGetSymbolAttributeName()) found = mlirSymbolTableLookup(s_table, mlirStringAttrGetValue(sym)) body = MLIR.Module.body(to) @@ -52,15 +52,9 @@ defmodule Charms.JIT do [head | tail] = modules for module <- tail do - if MLIR.is_null(module) do - raise "can't merge a null module" - end - - clone_func_impl(head, module) - - if destroy do - MLIR.Module.destroy(module) - end + if MLIR.is_null(module), do: raise("can't merge a null module") + clone_ops(head, module) + if destroy, do: MLIR.Module.destroy(module) end head diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index 1692ccd..656e3c3 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -92,7 +92,7 @@ defmodule Charms.Prelude do args = args |> Enum.zip(arg_types) |> Enum.map(&wrap_arg(&1, opts)) mlir ctx: opts[:ctx], block: opts[:block] do - Func.call(args, callee: Attribute.flat_symbol_ref("#{name}")) >>> + Func.call(args, callee: Attribute.flat_symbol_ref("#{name}"), loc: opts[:loc]) >>> case ret_types do [ret] -> ret diff --git a/test/defm_test.exs b/test/defm_test.exs index ba4ce9d..b19c5f4 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -44,15 +44,16 @@ defmodule DefmTest do end test "quick sort" do - assert ENIFQuickSort.sort(:what, :arg_err) == :arg_err + assert_raise ArgumentError, "list expected", fn -> ENIFQuickSort.sort(:what) end + arr = [5, 4, 3, 2, 1] - assert ENIFQuickSort.sort(arr, :arg_err) == Enum.sort(arr) + assert ENIFQuickSort.sort(arr) == Enum.sort(arr) for i <- 0..1000 do arr = 0..i |> Enum.shuffle() - assert ENIFTimSort.sort(arr, :arg_err) == Enum.sort(arr) - assert ENIFQuickSort.sort(arr, :arg_err) == Enum.sort(arr) - assert ENIFMergeSort.sort(arr, :arg_err) == Enum.sort(arr) + assert ENIFTimSort.sort(arr) == Enum.sort(arr) + assert ENIFQuickSort.sort(arr) == Enum.sort(arr) + assert ENIFMergeSort.sort(arr) == Enum.sort(arr) end assert :ok = Charms.JIT.destroy(ENIFQuickSort) diff --git a/test/mod_merge_test.exs b/test/mod_merge_test.exs new file mode 100644 index 0000000..b7a5f4d --- /dev/null +++ b/test/mod_merge_test.exs @@ -0,0 +1,14 @@ +defmodule ModMergeTest do + use ExUnit.Case, async: true + + test "attr" do + assert 1 = SubMod0.get_term(1) + assert 1 = SubMod1.get_term(1) + end + + test "func" do + assert SubMod0 = SubMod0.get_attr() + assert SubMod1 = SubMod1.get_attr() + assert SubMod0.get_attr_dup() == SubMod1.get_attr_dup() + end +end diff --git a/test/string_test.exs b/test/string_test.exs index c3ac49c..4a45df0 100644 --- a/test/string_test.exs +++ b/test/string_test.exs @@ -7,14 +7,14 @@ defmodule StringTest do alias Charms.{Pointer, Term} defm get(env) :: Term.t() do - str = "this is a string" + str = "this is a string!" str = "this is a string" term_ptr = Pointer.allocate(Term.t()) - d_ptr = enif_make_new_binary(env, String.length(str), term_ptr) - m = ptr_to_memref(d_ptr) + size = value index.casts(String.length(str)) :: i64() + d_ptr = enif_make_new_binary(env, size, term_ptr) + m = ptr_to_memref(d_ptr, size) memref.copy(str, m) - t = Pointer.load(Term.t(), term_ptr) - func.return(t) + Pointer.load(Term.t(), term_ptr) end end diff --git a/test/support/merge_mod.ex b/test/support/merge_mod.ex new file mode 100644 index 0000000..642c15a --- /dev/null +++ b/test/support/merge_mod.ex @@ -0,0 +1,38 @@ +defmodule SubMod0 do + use Charms + alias Charms.Term + + defm get_term(env, i) :: Term.t() do + func.return(i) + end + + @a __MODULE__ + defm get_attr(env) :: Term.t() do + func.return(@a) + end + + @b :some_attr + defm get_attr_dup(env) :: Term.t() do + func.return(@b) + end +end + +defmodule SubMod1 do + use Charms + alias Charms.Term + + defm get_term(env, i) :: Term.t() do + i = call SubMod0.get_term(env, i) :: Term.t() + func.return(i) + end + + @a __MODULE__ + defm get_attr(env) :: Term.t() do + func.return(@a) + end + + @b :some_attr + defm get_attr_dup(env) :: Term.t() do + func.return(@b) + end +end diff --git a/test/vec_add_test.exs b/test/vec_add_test.exs index 3664cc3..558bb9b 100644 --- a/test/vec_add_test.exs +++ b/test/vec_add_test.exs @@ -14,8 +14,7 @@ defmodule VecAddTest do alias Charms.SIMD defm six(env, a, b, error) do - v1 = SIMD.new(i32(), 8).(1, 1, 1, 1, 1, 1) - func.return() + SIMD.new(i32(), 8).(1, 1, 1, 1, 1, 1) end end end