Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand :|| and :! #47

Merged
merged 4 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ defmodule Charms.Defm.Expander do
{l, state, env} = expand(l, state, env)
{init, state, env} = expand(init, state, env)
result_t = MLIR.Value.type(init)
{list_term_ptr, state} = uniq_mlir_var(state, l)
{list_term_ptr, state} = uniq_mlir_var(l, state)
tail_ptr = uniq_mlir_var()
head_ptr = uniq_mlir_var()

Expand Down Expand Up @@ -470,7 +470,7 @@ defmodule Charms.Defm.Expander do

defp expand_intrinsics(loc, module, fun, args, state, env) do
{args, state, env} = expand(args, state, env)
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

case v =
module.handle_intrinsic(fun, params, args,
Expand Down Expand Up @@ -692,13 +692,21 @@ defmodule Charms.Defm.Expander do
{left, state, env} = expand(left, state, env)
{right, state, env} = expand(right, state, env)
loc = MLIR.Location.from_env(env)
{params, state} = uniq_mlir_params(state, [left, right])
{params, state} = uniq_mlir_params([left, right], state)

{Charms.Prelude.handle_intrinsic(fun, params, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}
try do
{Charms.Prelude.handle_intrinsic(fun, params, [left, right],
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: loc
), state, env}
rescue
e ->
raise_compile_error(
env,
"Failed to expand prelude intrinsic #{fun}: #{Exception.message(e)}"
)
end
end

## =/2
Expand Down Expand Up @@ -1043,6 +1051,20 @@ defmodule Charms.Defm.Expander do

v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
cond_type = MLIR.Value.type(condition)
bool_type = Type.i1(ctx: state.mlir.ctx)
# Ensure the condition is a i1, if not compare it to 0
condition =
if MLIR.equal?(cond_type, bool_type) do
condition
else
zero =
Arith.constant(value: Attribute.integer(cond_type, 0), loc: loc) >>> cond_type

Arith.cmpi(condition, zero, predicate: Arith.cmp_i_predicate(:sgt), loc: loc) >>>
Type.i1()
end

b =
block _true() do
ret_t =
Expand All @@ -1065,6 +1087,22 @@ defmodule Charms.Defm.Expander do
{v, state, env}
end

defp expand_macro(_meta, Kernel, :!, [value], _callback, state, env) do
{value, state, env} = expand(value, state, env)
type = MLIR.Value.type(value)
{value, state} = uniq_mlir_var(value, state)
{type, state} = uniq_mlir_var(type, state)

{not_value, state, env} =
quote do
one = const 1 :: unquote(type)
value arith.xori(unquote(value), one) :: unquote(type)
end
|> expand(state, env)

{List.last(not_value), state, env}
end

jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
defp expand_macro(_meta, Charms.Defm, :while, [expr, [do: body]], _callback, state, env) do
v =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
Expand Down Expand Up @@ -1246,7 +1284,7 @@ defmodule Charms.Defm.Expander do
defp expand_remote(_meta, Kernel, fun, args, state, env) when fun in @prelude_intrinsics do
loc = MLIR.Location.from_env(env)
{args, state, env} = expand(args, state, env)
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

{Charms.Prelude.handle_intrinsic(fun, params, args,
ctx: state.mlir.ctx,
Expand Down Expand Up @@ -1282,7 +1320,7 @@ defmodule Charms.Defm.Expander do
if function_exported?(MLIR.Type, fun, 1) do
{apply(MLIR.Type, fun, [[ctx: state.mlir.ctx]]), state, env}
else
{params, state} = uniq_mlir_params(state, args)
{params, state} = uniq_mlir_params(args, state)

case i =
Charms.Prelude.handle_intrinsic(fun, params, args,
Expand Down Expand Up @@ -1369,7 +1407,7 @@ defmodule Charms.Defm.Expander do

defp beam_env_from_defm!(env, state) do
if e = state.mlir.enif_env do
uniq_mlir_var(state, e)
uniq_mlir_var(e, state)
else
raise_compile_error(env, "must be a defm with beam env as the first argument")
end
Expand All @@ -1380,14 +1418,14 @@ defmodule Charms.Defm.Expander do
Macro.var(:"#{@var_prefix}#{System.unique_integer([:positive])}", nil)
end

defp uniq_mlir_var(state, val) do
defp uniq_mlir_var(val, state) do
uniq_mlir_var() |> then(&{&1, put_mlir_var(state, &1, val)})
end

defp uniq_mlir_params(state, args) when is_list(args) do
defp uniq_mlir_params(args, state) when is_list(args) do
for param <- args, reduce: {[], state} do
{params, %{mlir: _} = state} ->
{param, %{mlir: _} = state} = uniq_mlir_var(state, param)
{param, %{mlir: _} = state} = uniq_mlir_var(param, state)
{params ++ [param], state}
end
end
Expand Down
18 changes: 15 additions & 3 deletions lib/charms/prelude.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ defmodule Charms.Prelude do
use Charms.Intrinsic
alias Beaver.MLIR.Dialect.{Arith, Func}
@enif_functions Beaver.ENIF.functions()
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :*]
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :&&, :||, :*]

defp constant_of_same_type(i, v, opts) do
mlir ctx: opts[:ctx], block: opts[:block] do
t = MLIR.CAPI.mlirValueGetType(v)
Arith.constant(value: Attribute.integer(t, i)) >>> t

if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i)) >>> t
else
raise ArgumentError, "Not an integer type for constant, #{to_string(t)}"
end
end
end

Expand All @@ -21,7 +26,11 @@ defmodule Charms.Prelude do
i

i when is_integer(i) ->
Arith.constant(value: Attribute.integer(t, i)) >>> t
if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i)) >>> t
else
raise ArgumentError, "Not an integer type, #{to_string(t)}"
end
end
end
end
Expand Down Expand Up @@ -53,6 +62,9 @@ defmodule Charms.Prelude do
:&& ->
Arith.andi(operands) >>> type

:|| ->
Arith.ori(operands) >>> type

:* ->
Arith.muli(operands) >>> type
end
Expand Down
140 changes: 88 additions & 52 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
@@ -1,3 +1,58 @@
defmodule AddTwoInt do
use Charms, init: false
alias Charms.{Pointer, Term}

defm add_or_error(env, a, b, error) :: Term.t() do
ptr_a = Pointer.allocate(i32())
ptr_b = Pointer.allocate(i32())

arg_err =
block do
func.return(error)
end

cond_br enif_get_int64(env, a, ptr_a) != 0 do
cond_br 0 != enif_get_int(env, b, ptr_b) do
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
sum = value llvm.add(a, b) :: i32()
term = enif_make_int(env, sum)
func.return(term)
else
^arg_err
end
else
^arg_err
end
end
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

defm add0(env, a, b) :: Term.t() do
ptr_a = Pointer.allocate(i32())
ptr_b = Pointer.allocate(i32())

if enif_get_int(env, a, ptr_a) <= 0 || enif_get_int(env, b, ptr_b) <= 0 do
enif_make_badarg(env)
else
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
enif_make_int(env, a + b)
end
end
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

defm add(env, a, b) :: Term.t() do
ptr_a = Pointer.allocate(i32())
ptr_b = Pointer.allocate(i32())

if !enif_get_int(env, a, ptr_a) || !enif_get_int(env, b, ptr_b) do
enif_make_badarg(env)
else
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
enif_make_int(env, a + b)
end
end
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
end

defmodule DefmTest do
use ExUnit.Case, async: true

Expand All @@ -6,28 +61,32 @@ defmodule DefmTest do
end

test "invalid return of absent alias" do
assert_raise CompileError, "test/defm_test.exs:13: invalid return type", fn ->
defmodule InvalidRet do
use Charms

defm my_function(env, arg1, arg2) :: Invalid.t() do
func.return(arg2)
end
end
end
assert_raise CompileError,
"test/defm_test.exs:#{__ENV__.line + 5}: invalid return type",
fn ->
defmodule InvalidRet do
use Charms

defm my_function(env, arg1, arg2) :: Invalid.t() do
func.return(arg2)
end
end
end
end

test "invalid arg of absent alias" do
assert_raise CompileError, "test/defm_test.exs:26: invalid argument type #2", fn ->
defmodule InvalidRet do
use Charms
alias Charms.Term

defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do
func.return(arg2)
end
end
end
assert_raise CompileError,
"test/defm_test.exs:#{__ENV__.line + 6}: invalid argument type #2",
fn ->
defmodule InvalidRet do
use Charms
alias Charms.Term

defm my_function(env, arg1 :: Pointer.t(), arg2) :: Term.t() do
func.return(arg2)
end
end
end
end

test "only env defm is exported" do
Expand All @@ -39,41 +98,16 @@ defmodule DefmTest do
end

test "add two integers" do
defmodule AddTwoInt do
use Charms, init: false
alias Charms.{Pointer, Term}

defm add(env, a, b, error) :: Term.t() do
ptr_a = Pointer.allocate(i64())
ptr_b = Pointer.allocate(i64())

arg_err =
block do
func.return(error)
end

cond_br enif_get_int64(env, a, ptr_a) != 0 do
cond_br 0 != enif_get_int64(env, b, ptr_b) do
a = Pointer.load(i64(), ptr_a)
b = Pointer.load(i64(), ptr_b)
sum = value llvm.add(a, b) :: i64()
term = enif_make_int64(env, sum)
func.return(term)
else
^arg_err
end
else
^arg_err
end
end
end

assert {:ok, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int)
assert {:cached, %Charms.JIT{}} = Charms.JIT.init(AddTwoInt, name: :add_int)
engine = Charms.JIT.engine(:add_int)
assert String.starts_with?(AddTwoInt.__ir__(), "ML\xefR")
assert AddTwoInt.add(1, 2, :arg_err).(engine) == 3
assert AddTwoInt.add(1, "", :arg_err).(engine) == :arg_err
assert AddTwoInt.add(1, 2).(engine) == 3
assert_raise ArgumentError, fn -> AddTwoInt.add(1, "2").(engine) end
assert AddTwoInt.add0(1, 2).(engine) == 3
assert_raise ArgumentError, fn -> AddTwoInt.add0(1, "2").(engine) end
assert AddTwoInt.add_or_error(1, 2, :arg_err).(engine) == 3
assert AddTwoInt.add_or_error(1, "", :arg_err).(engine) == :arg_err
assert :ok = Charms.JIT.destroy(:add_int)
end

Expand Down Expand Up @@ -109,8 +143,10 @@ defmodule DefmTest do
end

test "undefined remote function" do
line = __ENV__.line

assert_raise CompileError,
"test/defm_test.exs:119: Failed to expand macro Elixir.DifferentCalls.something/1: test/defm_test.exs:119: function something not found in module DifferentCalls",
~r"Failed to expand macro Elixir.DifferentCalls.something/1.+function something not found in module DifferentCalls",
fn ->
defmodule Undefined do
use Charms
Expand All @@ -124,7 +160,7 @@ defmodule DefmTest do

test "wrong return type remote function" do
assert_raise CompileError,
"test/defm_test.exs:133: mismatch type in invocation: f32 vs. i64",
~r"mismatch type in invocation: f32 vs. i64",
fn ->
defmodule WrongReturnType do
use Charms
Expand Down
Loading