From db646e33c3b76a1c8f51300b49c0faecfad0d3bd Mon Sep 17 00:00:00 2001 From: tsai Date: Tue, 20 Aug 2024 07:28:23 +0800 Subject: [PATCH 1/4] rename struct_if -> if --- bench/enif_merge_sort.ex | 4 ++-- bench/enif_quick_sort.ex | 4 ++-- bench/enif_tim_sort.ex | 4 ++-- lib/charms/defm.ex | 5 ---- lib/charms/defm/expander.ex | 46 +++++++++++++++++++++++++++++-------- lib/charms/prelude.ex | 3 ++- test/if_test.exs | 32 ++++++++++++++++++++++++++ 7 files changed, 76 insertions(+), 22 deletions(-) create mode 100644 test/if_test.exs diff --git a/bench/enif_merge_sort.ex b/bench/enif_merge_sort.ex index 92cb711..94a9a44 100644 --- a/bench/enif_merge_sort.ex +++ b/bench/enif_merge_sort.ex @@ -40,7 +40,7 @@ defmodule ENIFMergeSort do left_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)) right_term = Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), right_temp, j)) - struct_if(enif_compare(left_term, right_term) <= 0) do + if(enif_compare(left_term, right_term) <= 0) do Pointer.store( Pointer.load(Term.t(), Pointer.element_ptr(Term.t(), left_temp, i)), Pointer.element_ptr(Term.t(), arr, k) @@ -89,7 +89,7 @@ defmodule ENIFMergeSort do end defm do_sort(arr :: Pointer.t(), l :: i32(), r :: i32()) do - struct_if(l < r) do + if(l < r) do two_const = op arith.constant(value: Attribute.integer(i32(), 2)) :: i32() two = result_at(two_const, 0) m = op arith.divsi(l + r, two) :: i32() diff --git a/bench/enif_quick_sort.ex b/bench/enif_quick_sort.ex index 65bba8d..5d1b9ea 100644 --- a/bench/enif_quick_sort.ex +++ b/bench/enif_quick_sort.ex @@ -22,7 +22,7 @@ defmodule ENIFQuickSort do start = Pointer.element_ptr(Term.t(), arr, low) for_loop {element, j} <- {Term.t(), start, high - low} do - struct_if(enif_compare(element, pivot) < 0) do + if(enif_compare(element, pivot) < 0) do i = Pointer.load(i32(), i_ptr) + 1 Pointer.store(i, i_ptr) j = value index.casts(j) :: i32() @@ -40,7 +40,7 @@ defmodule ENIFQuickSort do end defm do_sort(arr :: Pointer.t(), low :: i32(), high :: i32()) do - struct_if(low < high) do + if(low < high) do pi = call partition(arr, low, high) :: i32() do_sort(arr, low, pi - 1) do_sort(arr, pi + 1, high) diff --git a/bench/enif_tim_sort.ex b/bench/enif_tim_sort.ex index 0afb854..6960ab6 100644 --- a/bench/enif_tim_sort.ex +++ b/bench/enif_tim_sort.ex @@ -66,8 +66,8 @@ defmodule ENIFTimSort do right = op arith.minsi(left + 2 * size - 1, n - 1) :: i32() right = result_at(right, 0) - struct_if(mid < right) do - call ENIFMergeSort.merge(arr, left, mid, right) :: [] + if(mid < right) do + call ENIFMergeSort.merge(arr, left, mid, right) end Pointer.store(left + 2 * size, left_ptr) diff --git a/lib/charms/defm.ex b/lib/charms/defm.ex index 1ed66a9..8dd41c9 100644 --- a/lib/charms/defm.ex +++ b/lib/charms/defm.ex @@ -46,11 +46,6 @@ defmodule Charms.Defm do """ defmacro cond_br(_condition, _clauses), do: :implemented_in_expander - @doc """ - `if` expression requires identical types for both branches - """ - defmacro struct_if(_condition, _clauses), do: :implemented_in_expander - @doc """ define a function that can be JIT compiled diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index 1fcae07..d8cd5d2 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -711,7 +711,7 @@ defmodule Charms.Defm.Expander do {v, state, env} end - defp expand_macro(_meta, Charms.Defm, :struct_if, [condition, clauses], _callback, state, env) do + defp expand_macro(_meta, Kernel, :if, [condition, clauses], _callback, state, env) do true_body = Keyword.fetch!(clauses, :do) false_body = clauses[:else] {condition, state, env} = expand(condition, state, env) @@ -720,24 +720,50 @@ defmodule Charms.Defm.Expander do mlir ctx: state.mlir.ctx, block: state.mlir.blk do alias Beaver.MLIR.Dialect.SCF + b = + block _true() do + ret_t = + with {ret, _, _} <- + expand(true_body, put_in(state.mlir.blk, Beaver.Env.block()), env), + %MLIR.Value{} = ret when not is_nil(false_body) <- + ret |> List.wrap() |> List.last() do + SCF.yield(ret) >>> [] + MLIR.Value.type(ret) + else + %MLIR.Operation{} -> + SCF.yield() >>> [] + [] + + %MLIR.Value{} = ret -> + SCF.yield(ret) >>> [] + MLIR.Value.type(ret) + end + end + + # TODO: doc about an expression which is a value and an operation SCF.if [condition] do region do - block _true() do - expand(true_body, put_in(state.mlir.blk, Beaver.Env.block()), env) - SCF.yield() >>> [] - end + MLIR.CAPI.mlirRegionAppendOwnedBlock(Beaver.Env.region(), b) end region do block _false() do - if false_body do - expand(false_body, put_in(state.mlir.blk, Beaver.Env.block()), env) + with {ret, _, _} <- + unless(is_nil(false_body), + do: expand(false_body, put_in(state.mlir.blk, Beaver.Env.block()), env) + ), + %MLIR.Value{} = ret <- ret |> List.wrap() |> List.last() do + SCF.yield(ret) >>> [] + else + %MLIR.Value{} = ret -> + SCF.yield(ret) >>> [] + + _ -> + SCF.yield() >>> [] end - - SCF.yield() >>> [] end end - end >>> [] + end >>> ret_t end {v, state, env} diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index 8b349da..6fe3722 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -8,7 +8,8 @@ defmodule Charms.Prelude do @enif_functions ++ [:result_at] ++ @binary_ops end - defp constant_of_same_type(i, v, opts) do + @doc false + def constant_of_same_type(i, v, opts) do mlir ctx: opts[:ctx], block: opts[:block] do t = MLIR.CAPI.mlirValueGetType(v) Arith.constant(value: Attribute.integer(t, i)) >>> t diff --git a/test/if_test.exs b/test/if_test.exs new file mode 100644 index 0000000..d29e582 --- /dev/null +++ b/test/if_test.exs @@ -0,0 +1,32 @@ +defmodule IfTest do + use ExUnit.Case + + test "if with value" do + defmodule GetIntIf do + use Charms + alias Charms.{Pointer, Term} + + defm get(env, i) :: Term.t() do + zero = arith.constant(value: Attribute.integer(i32(), 0)) + one = arith.constant(value: Attribute.integer(i32(), 1)) + i_ptr = Pointer.allocate(i32()) + enif_get_int(env, i, i_ptr) + i = Pointer.load(i32(), i_ptr) + + ret = + if(i > 0) do + one + else + zero + end + + ret = enif_make_int(env, ret) + func.return(ret) + end + end + |> Charms.JIT.init() + + assert GetIntIf.get(100) == 1 + assert GetIntIf.get(-100) == 0 + end +end From 2930879e1e81046bda67ac7eb2e20ac3f7fc6ce1 Mon Sep 17 00:00:00 2001 From: tsai Date: Tue, 20 Aug 2024 07:37:40 +0800 Subject: [PATCH 2/4] Update prelude.ex --- lib/charms/prelude.ex | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/charms/prelude.ex b/lib/charms/prelude.ex index 6fe3722..8b349da 100644 --- a/lib/charms/prelude.ex +++ b/lib/charms/prelude.ex @@ -8,8 +8,7 @@ defmodule Charms.Prelude do @enif_functions ++ [:result_at] ++ @binary_ops end - @doc false - def constant_of_same_type(i, v, opts) do + defp constant_of_same_type(i, v, opts) do mlir ctx: opts[:ctx], block: opts[:block] do t = MLIR.CAPI.mlirValueGetType(v) Arith.constant(value: Attribute.integer(t, i)) >>> t From 840e2f38b72865c4b93deb77e3343fcc4d51b20b Mon Sep 17 00:00:00 2001 From: tsai Date: Tue, 20 Aug 2024 08:01:14 +0800 Subject: [PATCH 3/4] Update expander.ex --- lib/charms/defm/expander.ex | 53 +++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index d8cd5d2..ed45ee3 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -570,6 +570,29 @@ defmodule Charms.Defm.Expander do Module.concat(mod, func) end + defp expand_if_clause_body(nil, state, _env) do + mlir ctx: state.mlir.ctx, block: state.mlir.blk do + SCF.yield() >>> [] + [] + end + end + + defp expand_if_clause_body(clause_body, state, env) do + mlir ctx: state.mlir.ctx, block: state.mlir.blk do + {ret, _, _} = expand(clause_body, state, env) + + case ret |> List.wrap() |> List.last() do + %MLIR.Operation{} -> + SCF.yield() >>> [] + [] + + %MLIR.Value{} = last -> + SCF.yield(last) >>> [] + MLIR.Value.type(last) + end + end + end + ## Macro handling # This is going to be the function where you will intercept expansions @@ -723,21 +746,7 @@ defmodule Charms.Defm.Expander do b = block _true() do ret_t = - with {ret, _, _} <- - expand(true_body, put_in(state.mlir.blk, Beaver.Env.block()), env), - %MLIR.Value{} = ret when not is_nil(false_body) <- - ret |> List.wrap() |> List.last() do - SCF.yield(ret) >>> [] - MLIR.Value.type(ret) - else - %MLIR.Operation{} -> - SCF.yield() >>> [] - [] - - %MLIR.Value{} = ret -> - SCF.yield(ret) >>> [] - MLIR.Value.type(ret) - end + expand_if_clause_body(true_body, put_in(state.mlir.blk, Beaver.Env.block()), env) end # TODO: doc about an expression which is a value and an operation @@ -748,19 +757,7 @@ defmodule Charms.Defm.Expander do region do block _false() do - with {ret, _, _} <- - unless(is_nil(false_body), - do: expand(false_body, put_in(state.mlir.blk, Beaver.Env.block()), env) - ), - %MLIR.Value{} = ret <- ret |> List.wrap() |> List.last() do - SCF.yield(ret) >>> [] - else - %MLIR.Value{} = ret -> - SCF.yield(ret) >>> [] - - _ -> - SCF.yield() >>> [] - end + expand_if_clause_body(false_body, put_in(state.mlir.blk, Beaver.Env.block()), env) end end end >>> ret_t From 78df799065f98cbf1f1c4d7e07d11e31a51f090a Mon Sep 17 00:00:00 2001 From: tsai Date: Tue, 20 Aug 2024 08:20:06 +0800 Subject: [PATCH 4/4] Update expander.ex --- lib/charms/defm/expander.ex | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index ed45ee3..0a774e7 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -570,6 +570,7 @@ defmodule Charms.Defm.Expander do Module.concat(mod, func) end + # Expands a nil clause body in an if statement, yielding no value. defp expand_if_clause_body(nil, state, _env) do mlir ctx: state.mlir.ctx, block: state.mlir.blk do SCF.yield() >>> [] @@ -577,6 +578,7 @@ defmodule Charms.Defm.Expander do end end + # Expands a non-nil clause body in an if statement, yielding the last evaluated value. defp expand_if_clause_body(clause_body, state, env) do mlir ctx: state.mlir.ctx, block: state.mlir.blk do {ret, _, _} = expand(clause_body, state, env) @@ -734,6 +736,7 @@ defmodule Charms.Defm.Expander do {v, state, env} end + # Expands an `if` expression, handling both true and false clause bodies. defp expand_macro(_meta, Kernel, :if, [condition, clauses], _callback, state, env) do true_body = Keyword.fetch!(clauses, :do) false_body = clauses[:else]