From 61f500bf7ec1f47e420c35b6fac18162a0080e73 Mon Sep 17 00:00:00 2001 From: tsai Date: Mon, 11 Nov 2024 22:09:50 +0800 Subject: [PATCH 1/4] Expand :|| and :! --- lib/charms/defm/expander.ex | 65 +++++++++++++---- lib/charms/prelude.ex | 18 ++++- test/defm_test.exs | 140 ++++++++++++++++++++++-------------- 3 files changed, 154 insertions(+), 69 deletions(-) diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 426a9b9..1f26f9b 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -279,7 +279,7 @@ defmodule Charms.Defm.Expander do {l, state, env} = expand(l, state, env) {init, state, env} = expand(init, state, env) result_t = MLIR.Value.type(init) - {list_term_ptr, state} = uniq_mlir_var(state, l) + {list_term_ptr, state} = uniq_mlir_var(l, state) tail_ptr = uniq_mlir_var() head_ptr = uniq_mlir_var() @@ -470,7 +470,7 @@ defmodule Charms.Defm.Expander do defp expand_intrinsics(loc, module, fun, args, state, env) do {args, state, env} = expand(args, state, env) - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) case v = module.handle_intrinsic(fun, params, args, @@ -692,13 +692,21 @@ defmodule Charms.Defm.Expander do {left, state, env} = expand(left, state, env) {right, state, env} = expand(right, state, env) loc = MLIR.Location.from_env(env) - {params, state} = uniq_mlir_params(state, [left, right]) + {params, state} = uniq_mlir_params([left, right], state) - {Charms.Prelude.handle_intrinsic(fun, params, [left, right], - ctx: state.mlir.ctx, - block: state.mlir.blk, - loc: loc - ), state, env} + try do + {Charms.Prelude.handle_intrinsic(fun, params, [left, right], + ctx: state.mlir.ctx, + block: state.mlir.blk, + loc: loc + ), state, env} + rescue + e -> + raise_compile_error( + env, + "Failed to expand prelude intrinsic #{fun}: #{Exception.message(e)}" + ) + end end ## =/2 @@ -1043,6 +1051,19 @@ defmodule Charms.Defm.Expander do v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do + cond_type = MLIR.Value.type(condition) + + # Ensure the condition is a i1, if not compare it to 0 + condition = + if MLIR.equal?(cond_type, Type.i1(ctx: state.mlir.ctx)) do + condition + else + zero = + Arith.constant(value: Attribute.integer(cond_type, 0), loc: loc) >>> cond_type + + Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt)) >>> Type.i1() + end + b = block _true() do ret_t = @@ -1065,6 +1086,22 @@ defmodule Charms.Defm.Expander do {v, state, env} end + defp expand_macro(_meta, Kernel, :!, [value], _callback, state, env) do + {value, state, env} = expand(value, state, env) + type = MLIR.Value.type(value) + {value, state} = uniq_mlir_var(value, state) + {type, state} = uniq_mlir_var(type, state) + + {not_value, state, env} = + quote do + one = const 1 :: unquote(type) + value arith.xori(unquote(value), one) :: unquote(type) + end + |> expand(state, env) + + {List.last(not_value), state, env} + end + defp expand_macro(_meta, Charms.Defm, :while, [expr, [do: body]], _callback, state, env) do v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do @@ -1246,7 +1283,7 @@ defmodule Charms.Defm.Expander do defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @prelude_intrinsics do loc = MLIR.Location.from_env(env) {args, state, env} = expand(args, state, env) - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) {Charms.Prelude.handle_intrinsic(fun, params, args, ctx: state.mlir.ctx, @@ -1282,7 +1319,7 @@ defmodule Charms.Defm.Expander do if function_exported?(MLIR.Type, fun, 1) do {apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env} else - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) case i = Charms.Prelude.handle_intrinsic(fun, params, args, @@ -1369,7 +1406,7 @@ defmodule Charms.Defm.Expander do defp beam_env_from_defm!(env, state) do if e = state.mlir.enif_env do - uniq_mlir_var(state, e) + uniq_mlir_var(e, state) else raise_compile_error(env, "must be a defm with beam env as the first argument") end @@ -1380,14 +1417,14 @@ defmodule Charms.Defm.Expander do Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil) end - defp uniq_mlir_var(state, val) do + defp uniq_mlir_var(val, state) do uniq_mlir_var() |> then(&{&1, put_mlir_var(state, &1, val)}) end - defp uniq_mlir_params(state, args) when is_list(args) do + defp uniq_mlir_params(args, state) when is_list(args) do for param <- args, reduce: {[], state} do {params, %{mlir: _} = state} -> - {param, %{mlir: _} = state} = uniq_mlir_var(state, param) + {param, %{mlir: _} = state} = uniq_mlir_var(param, state) {params ++ [param], state} end end diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index d5a2f40..33ad4c5 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -5,12 +5,17 @@ defmodule Charms.Prelude do use Charms.Intrinsic alias Beaver.MLIR.Dialect.{Arith, Func} @enif_functions Beaver.ENIF.functions() - @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :*] + @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :||, :*] defp constant_of_same_type(i, v, opts) do mlir ctx: opts[:ctx], block: opts[:block] do t = MLIR.CAPI.mlirValueGetType(v) - Arith.constant(value: Attribute.integer(t, i)) >>> t + + if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do + Arith.constant(value: Attribute.integer(t, i)) >>> t + else + raise ArgumentError, "Not an integer type for constant, #{to_string(t)}" + end end end @@ -21,7 +26,11 @@ defmodule Charms.Prelude do i i when is_integer(i) -> - Arith.constant(value: Attribute.integer(t, i)) >>> t + if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do + Arith.constant(value: Attribute.integer(t, i)) >>> t + else + raise ArgumentError, "Not an integer type, #{to_string(t)}" + end end end end @@ -53,6 +62,9 @@ defmodule Charms.Prelude do :&& -> Arith.andi(operands) >>> type + :|| -> + Arith.ori(operands) >>> type + :* -> Arith.muli(operands) >>> type end diff --git a/test/defm_test.exs b/test/defm_test.exs index 8dec714..b49dbca 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -1,3 +1,58 @@ +defmodule AddTwoInt do + use Charms, init: false + alias Charms.{Pointer, Term} + + defm add_or_error(env, a, b, error) :: Term.t() do + ptr_a = Pointer.allocate(i64()) + ptr_b = Pointer.allocate(i64()) + + arg_err = + block do + func.return(error) + end + + cond_br enif_get_int64(env, a, ptr_a) != 0 do + cond_br 0 != enif_get_int64(env, b, ptr_b) do + a = Pointer.load(i64(), ptr_a) + b = Pointer.load(i64(), ptr_b) + sum = value llvm.add(a, b) :: i64() + term = enif_make_int64(env, sum) + func.return(term) + else + ^arg_err + end + else + ^arg_err + end + end + + defm add0(env, a, b) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + if enif_get_int(env, a, ptr_a) <= 0 || enif_get_int(env, b, ptr_b) <= 0 do + enif_make_badarg(env) + else + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + enif_make_int(env, a + b) + end + end + + defm add(env, a, b) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do + enif_make_badarg(env) + else + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + enif_make_int(env, a + b) + end + end +end + defmodule DefmTest do use ExUnit.Case, async: true @@ -6,28 +61,32 @@ defmodule DefmTest do end test "invalid return of absent alias" do - assert_raise CompileError, "test/defm_test.exs:13: invalid return type", fn -> - defmodule InvalidRet do - use Charms - - defm my_function(env, arg1, arg2) :: Invalid.t() do - func.return(arg2) - end - end - end + assert_raise CompileError, + "test/defm_test.exs:#{__ENV__.line + 5}: invalid return type", + fn -> + defmodule InvalidRet do + use Charms + + defm my_function(env, arg1, arg2) :: Invalid.t() do + func.return(arg2) + end + end + end end test "invalid arg of absent alias" do - assert_raise CompileError, "test/defm_test.exs:26: invalid argument type #2", fn -> - defmodule InvalidRet do - use Charms - alias Charms.Term - - defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do - func.return(arg2) - end - end - end + assert_raise CompileError, + "test/defm_test.exs:#{__ENV__.line + 6}: invalid argument type #2", + fn -> + defmodule InvalidRet do + use Charms + alias Charms.Term + + defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do + func.return(arg2) + end + end + end end test "only env defm is exported" do @@ -39,41 +98,16 @@ defmodule DefmTest do end test "add two integers" do - defmodule AddTwoInt do - use Charms, init: false - alias Charms.{Pointer, Term} - - defm add(env, a, b, error) :: Term.t() do - ptr_a = Pointer.allocate(i64()) - ptr_b = Pointer.allocate(i64()) - - arg_err = - block do - func.return(error) - end - - cond_br enif_get_int64(env, a, ptr_a) != 0 do - cond_br 0 != enif_get_int64(env, b, ptr_b) do - a = Pointer.load(i64(), ptr_a) - b = Pointer.load(i64(), ptr_b) - sum = value llvm.add(a, b) :: i64() - term = enif_make_int64(env, sum) - func.return(term) - else - ^arg_err - end - else - ^arg_err - end - end - end - assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) engine = Charms.JIT.engine(:add_int) assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR") - assert AddTwoInt.add(1, 2, :arg_err).(engine) == 3 - assert AddTwoInt.add(1, "", :arg_err).(engine) == :arg_err + assert AddTwoInt.add(1, 2).(engine) == 3 + assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end + assert AddTwoInt.add0(1, 2).(engine) == 3 + assert_raise ArgumentError, fn -> AddTwoInt.add0(1, "2").(engine) end + assert AddTwoInt.add_or_error(1, 2, :arg_err).(engine) == 3 + assert AddTwoInt.add_or_error(1, "", :arg_err).(engine) == :arg_err assert :ok = Charms.JIT.destroy(:add_int) end @@ -109,8 +143,10 @@ defmodule DefmTest do end test "undefined remote function" do + line = __ENV__.line + assert_raise CompileError, - "test/defm_test.exs:119: Failed to expand macro Elixir.DifferentCalls.something/1: test/defm_test.exs:119: function something not found in module DifferentCalls", + ~r"Failed to expand macro Elixir.DifferentCalls.something/1.+function something not found in module DifferentCalls", fn -> defmodule Undefined do use Charms @@ -124,7 +160,7 @@ defmodule DefmTest do test "wrong return type remote function" do assert_raise CompileError, - "test/defm_test.exs:133: mismatch type in invocation: f32 vs. i64", + ~r"mismatch type in invocation: f32 vs. i64", fn -> defmodule WrongReturnType do use Charms From f9397961ad9360d94f16f8536331570b0072e364 Mon Sep 17 00:00:00 2001 From: tsai Date: Mon, 11 Nov 2024 22:21:43 +0800 Subject: [PATCH 2/4] address review --- lib/charms/defm/expander.ex | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 1f26f9b..42a32b4 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -1052,16 +1052,17 @@ defmodule Charms.Defm.Expander do v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do cond_type = MLIR.Value.type(condition) - + bool_type = Type.i1(ctx: state.mlir.ctx) # Ensure the condition is a i1, if not compare it to 0 condition = - if MLIR.equal?(cond_type, Type.i1(ctx: state.mlir.ctx)) do + if MLIR.equal?(cond_type, bool_type) do condition else zero = Arith.constant(value: Attribute.integer(cond_type, 0), loc: loc) >>> cond_type - Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt)) >>> Type.i1() + Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt), loc: loc) >>> + Type.i1() end b = From 85d01d1efb12d42a10ae9728c75f04ae38055cb2 Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Tue, 12 Nov 2024 15:35:12 +0800 Subject: [PATCH 3/4] Update defm_test.exs --- test/defm_test.exs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/defm_test.exs b/test/defm_test.exs index b49dbca..f846b46 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -3,8 +3,8 @@ defmodule AddTwoInt do alias Charms.{Pointer, Term} defm add_or_error(env, a, b, error) :: Term.t() do - ptr_a = Pointer.allocate(i64()) - ptr_b = Pointer.allocate(i64()) + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) arg_err = block do @@ -12,11 +12,11 @@ defmodule AddTwoInt do end cond_br enif_get_int64(env, a, ptr_a) != 0 do - cond_br 0 != enif_get_int64(env, b, ptr_b) do - a = Pointer.load(i64(), ptr_a) - b = Pointer.load(i64(), ptr_b) - sum = value llvm.add(a, b) :: i64() - term = enif_make_int64(env, sum) + cond_br 0 != enif_get_int(env, b, ptr_b) do + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + sum = value llvm.add(a, b) :: i32() + term = enif_make_int(env, sum) func.return(term) else ^arg_err From 6dc5378c40e64a5d05c0b733639f3040cd328e66 Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Tue, 12 Nov 2024 16:01:59 +0800 Subject: [PATCH 4/4] Update defm_test.exs --- test/defm_test.exs | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/test/defm_test.exs b/test/defm_test.exs index f846b46..ba7391e 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -2,7 +2,7 @@ defmodule AddTwoInt do use Charms, init: false alias Charms.{Pointer, Term} - defm add_or_error(env, a, b, error) :: Term.t() do + defm add_or_error_with_cond_br(env, a, b, error) :: Term.t() do ptr_a = Pointer.allocate(i32()) ptr_b = Pointer.allocate(i32()) @@ -11,7 +11,7 @@ defmodule AddTwoInt do func.return(error) end - cond_br enif_get_int64(env, a, ptr_a) != 0 do + cond_br enif_get_int(env, a, ptr_a) != 0 do cond_br 0 != enif_get_int(env, b, ptr_b) do a = Pointer.load(i32(), ptr_a) b = Pointer.load(i32(), ptr_b) @@ -26,19 +26,6 @@ defmodule AddTwoInt do end end - defm add0(env, a, b) :: Term.t() do - ptr_a = Pointer.allocate(i32()) - ptr_b = Pointer.allocate(i32()) - - if enif_get_int(env, a, ptr_a) <= 0 || enif_get_int(env, b, ptr_b) <= 0 do - enif_make_badarg(env) - else - a = Pointer.load(i32(), ptr_a) - b = Pointer.load(i32(), ptr_b) - enif_make_int(env, a + b) - end - end - defm add(env, a, b) :: Term.t() do ptr_a = Pointer.allocate(i32()) ptr_b = Pointer.allocate(i32()) @@ -104,10 +91,8 @@ defmodule DefmTest do assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR") assert AddTwoInt.add(1, 2).(engine) == 3 assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end - assert AddTwoInt.add0(1, 2).(engine) == 3 - assert_raise ArgumentError, fn -> AddTwoInt.add0(1, "2").(engine) end - assert AddTwoInt.add_or_error(1, 2, :arg_err).(engine) == 3 - assert AddTwoInt.add_or_error(1, "", :arg_err).(engine) == :arg_err + assert AddTwoInt.add_or_error_with_cond_br(1, 2, :arg_err).(engine) == 3 + assert AddTwoInt.add_or_error_with_cond_br(1, "", :arg_err).(engine) == :arg_err assert :ok = Charms.JIT.destroy(:add_int) end