diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 426a9b9..42a32b4 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -279,7 +279,7 @@ defmodule Charms.Defm.Expander do {l, state, env} = expand(l, state, env) {init, state, env} = expand(init, state, env) result_t = MLIR.Value.type(init) - {list_term_ptr, state} = uniq_mlir_var(state, l) + {list_term_ptr, state} = uniq_mlir_var(l, state) tail_ptr = uniq_mlir_var() head_ptr = uniq_mlir_var() @@ -470,7 +470,7 @@ defmodule Charms.Defm.Expander do defp expand_intrinsics(loc, module, fun, args, state, env) do {args, state, env} = expand(args, state, env) - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) case v = module.handle_intrinsic(fun, params, args, @@ -692,13 +692,21 @@ defmodule Charms.Defm.Expander do {left, state, env} = expand(left, state, env) {right, state, env} = expand(right, state, env) loc = MLIR.Location.from_env(env) - {params, state} = uniq_mlir_params(state, [left, right]) + {params, state} = uniq_mlir_params([left, right], state) - {Charms.Prelude.handle_intrinsic(fun, params, [left, right], - ctx: state.mlir.ctx, - block: state.mlir.blk, - loc: loc - ), state, env} + try do + {Charms.Prelude.handle_intrinsic(fun, params, [left, right], + ctx: state.mlir.ctx, + block: state.mlir.blk, + loc: loc + ), state, env} + rescue + e -> + raise_compile_error( + env, + "Failed to expand prelude intrinsic #{fun}: #{Exception.message(e)}" + ) + end end ## =/2 @@ -1043,6 +1051,20 @@ defmodule Charms.Defm.Expander do v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do + cond_type = MLIR.Value.type(condition) + bool_type = Type.i1(ctx: state.mlir.ctx) + # Ensure the condition is a i1, if not compare it to 0 + condition = + if MLIR.equal?(cond_type, bool_type) do + condition + else + zero = + Arith.constant(value: Attribute.integer(cond_type, 0), loc: loc) >>> cond_type + + Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt), loc: loc) >>> + Type.i1() + end + b = block _true() do ret_t = @@ -1065,6 +1087,22 @@ defmodule Charms.Defm.Expander do {v, state, env} end + defp expand_macro(_meta, Kernel, :!, [value], _callback, state, env) do + {value, state, env} = expand(value, state, env) + type = MLIR.Value.type(value) + {value, state} = uniq_mlir_var(value, state) + {type, state} = uniq_mlir_var(type, state) + + {not_value, state, env} = + quote do + one = const 1 :: unquote(type) + value arith.xori(unquote(value), one) :: unquote(type) + end + |> expand(state, env) + + {List.last(not_value), state, env} + end + defp expand_macro(_meta, Charms.Defm, :while, [expr, [do: body]], _callback, state, env) do v = mlir ctx: state.mlir.ctx, block: state.mlir.blk do @@ -1246,7 +1284,7 @@ defmodule Charms.Defm.Expander do defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @prelude_intrinsics do loc = MLIR.Location.from_env(env) {args, state, env} = expand(args, state, env) - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) {Charms.Prelude.handle_intrinsic(fun, params, args, ctx: state.mlir.ctx, @@ -1282,7 +1320,7 @@ defmodule Charms.Defm.Expander do if function_exported?(MLIR.Type, fun, 1) do {apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env} else - {params, state} = uniq_mlir_params(state, args) + {params, state} = uniq_mlir_params(args, state) case i = Charms.Prelude.handle_intrinsic(fun, params, args, @@ -1369,7 +1407,7 @@ defmodule Charms.Defm.Expander do defp beam_env_from_defm!(env, state) do if e = state.mlir.enif_env do - uniq_mlir_var(state, e) + uniq_mlir_var(e, state) else raise_compile_error(env, "must be a defm with beam env as the first argument") end @@ -1380,14 +1418,14 @@ defmodule Charms.Defm.Expander do Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil) end - defp uniq_mlir_var(state, val) do + defp uniq_mlir_var(val, state) do uniq_mlir_var() |> then(&{&1, put_mlir_var(state, &1, val)}) end - defp uniq_mlir_params(state, args) when is_list(args) do + defp uniq_mlir_params(args, state) when is_list(args) do for param <- args, reduce: {[], state} do {params, %{mlir: _} = state} -> - {param, %{mlir: _} = state} = uniq_mlir_var(state, param) + {param, %{mlir: _} = state} = uniq_mlir_var(param, state) {params ++ [param], state} end end diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index d5a2f40..33ad4c5 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -5,12 +5,17 @@ defmodule Charms.Prelude do use Charms.Intrinsic alias Beaver.MLIR.Dialect.{Arith, Func} @enif_functions Beaver.ENIF.functions() - @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :*] + @binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :||, :*] defp constant_of_same_type(i, v, opts) do mlir ctx: opts[:ctx], block: opts[:block] do t = MLIR.CAPI.mlirValueGetType(v) - Arith.constant(value: Attribute.integer(t, i)) >>> t + + if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do + Arith.constant(value: Attribute.integer(t, i)) >>> t + else + raise ArgumentError, "Not an integer type for constant, #{to_string(t)}" + end end end @@ -21,7 +26,11 @@ defmodule Charms.Prelude do i i when is_integer(i) -> - Arith.constant(value: Attribute.integer(t, i)) >>> t + if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do + Arith.constant(value: Attribute.integer(t, i)) >>> t + else + raise ArgumentError, "Not an integer type, #{to_string(t)}" + end end end end @@ -53,6 +62,9 @@ defmodule Charms.Prelude do :&& -> Arith.andi(operands) >>> type + :|| -> + Arith.ori(operands) >>> type + :* -> Arith.muli(operands) >>> type end diff --git a/test/defm_test.exs b/test/defm_test.exs index 8dec714..ba7391e 100644 --- a/test/defm_test.exs +++ b/test/defm_test.exs @@ -1,3 +1,45 @@ +defmodule AddTwoInt do + use Charms, init: false + alias Charms.{Pointer, Term} + + defm add_or_error_with_cond_br(env, a, b, error) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + arg_err = + block do + func.return(error) + end + + cond_br enif_get_int(env, a, ptr_a) != 0 do + cond_br 0 != enif_get_int(env, b, ptr_b) do + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + sum = value llvm.add(a, b) :: i32() + term = enif_make_int(env, sum) + func.return(term) + else + ^arg_err + end + else + ^arg_err + end + end + + defm add(env, a, b) :: Term.t() do + ptr_a = Pointer.allocate(i32()) + ptr_b = Pointer.allocate(i32()) + + if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do + enif_make_badarg(env) + else + a = Pointer.load(i32(), ptr_a) + b = Pointer.load(i32(), ptr_b) + enif_make_int(env, a + b) + end + end +end + defmodule DefmTest do use ExUnit.Case, async: true @@ -6,28 +48,32 @@ defmodule DefmTest do end test "invalid return of absent alias" do - assert_raise CompileError, "test/defm_test.exs:13: invalid return type", fn -> - defmodule InvalidRet do - use Charms - - defm my_function(env, arg1, arg2) :: Invalid.t() do - func.return(arg2) - end - end - end + assert_raise CompileError, + "test/defm_test.exs:#{__ENV__.line + 5}: invalid return type", + fn -> + defmodule InvalidRet do + use Charms + + defm my_function(env, arg1, arg2) :: Invalid.t() do + func.return(arg2) + end + end + end end test "invalid arg of absent alias" do - assert_raise CompileError, "test/defm_test.exs:26: invalid argument type #2", fn -> - defmodule InvalidRet do - use Charms - alias Charms.Term - - defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do - func.return(arg2) - end - end - end + assert_raise CompileError, + "test/defm_test.exs:#{__ENV__.line + 6}: invalid argument type #2", + fn -> + defmodule InvalidRet do + use Charms + alias Charms.Term + + defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do + func.return(arg2) + end + end + end end test "only env defm is exported" do @@ -39,41 +85,14 @@ defmodule DefmTest do end test "add two integers" do - defmodule AddTwoInt do - use Charms, init: false - alias Charms.{Pointer, Term} - - defm add(env, a, b, error) :: Term.t() do - ptr_a = Pointer.allocate(i64()) - ptr_b = Pointer.allocate(i64()) - - arg_err = - block do - func.return(error) - end - - cond_br enif_get_int64(env, a, ptr_a) != 0 do - cond_br 0 != enif_get_int64(env, b, ptr_b) do - a = Pointer.load(i64(), ptr_a) - b = Pointer.load(i64(), ptr_b) - sum = value llvm.add(a, b) :: i64() - term = enif_make_int64(env, sum) - func.return(term) - else - ^arg_err - end - else - ^arg_err - end - end - end - assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int) engine = Charms.JIT.engine(:add_int) assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR") - assert AddTwoInt.add(1, 2, :arg_err).(engine) == 3 - assert AddTwoInt.add(1, "", :arg_err).(engine) == :arg_err + assert AddTwoInt.add(1, 2).(engine) == 3 + assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end + assert AddTwoInt.add_or_error_with_cond_br(1, 2, :arg_err).(engine) == 3 + assert AddTwoInt.add_or_error_with_cond_br(1, "", :arg_err).(engine) == :arg_err assert :ok = Charms.JIT.destroy(:add_int) end @@ -109,8 +128,10 @@ defmodule DefmTest do end test "undefined remote function" do + line = __ENV__.line + assert_raise CompileError, - "test/defm_test.exs:119: Failed to expand macro Elixir.DifferentCalls.something/1: test/defm_test.exs:119: function something not found in module DifferentCalls", + ~r"Failed to expand macro Elixir.DifferentCalls.something/1.+function something not found in module DifferentCalls", fn -> defmodule Undefined do use Charms @@ -124,7 +145,7 @@ defmodule DefmTest do test "wrong return type remote function" do assert_raise CompileError, - "test/defm_test.exs:133: mismatch type in invocation: f32 vs. i64", + ~r"mismatch type in invocation: f32 vs. i64", fn -> defmodule WrongReturnType do use Charms