Skip to content

Commit

Permalink
Vector add example (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Jul 9, 2024
1 parent 21476ed commit c04abfd
Show file tree
Hide file tree
Showing 12 changed files with 313 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/elixir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ jobs:
run: mix deps.get
- name: Run tests
run: mix test
- name: Benchmark add
run: |
mix run bench/list_add.exs
- name: Benchmark sort
run: |
mix run bench/sort.exs
25 changes: 25 additions & 0 deletions bench/list_add.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
mod = AddTwoIntVec
Charms.JIT.init(mod)

Benchee.run(
%{
"Enum.zip_reduce" => fn {a, b} ->
Enum.zip_reduce(a, b, [], fn x, y, acc -> [x + y | acc] end) |> Enum.reverse()
end,
"AddTwoIntVec.add" => fn {a, b} -> AddTwoIntVec.add(a, b, :err_msg) end,
"AddTwoIntVec.dummy_load_no_make" => fn {a, b} ->
AddTwoIntVec.dummy_load_no_make(a, b, :err_msg)
end,
"AddTwoIntVec.dummy_return" => fn {a, b} -> AddTwoIntVec.dummy_return(a, b, :err_msg) end
},
inputs: %{
"array size 8" => 8
},
before_scenario: fn i ->
a = Enum.to_list(1..i) |> Enum.shuffle()
b = Enum.to_list(1..i) |> Enum.shuffle()
{a, b}
end
)

Charms.JIT.destroy(mod)
20 changes: 14 additions & 6 deletions bench/sort.exs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,20 @@ Benchee.run(
"enif_merge_sort" => fn arr -> ENIFMergeSort.sort(arr, :arg_err) end,
"enif_tim_sort" => fn arr -> ENIFTimSort.sort(arr, :arg_err) end
},
inputs: %{
"array size 10" => 10,
"array size 100" => 100,
"array size 1000" => 1000,
"array size 67_000" => 67_000
},
inputs:
%{
"array size 10" => 10,
"array size 100" => 100,
"array size 1000" => 1000
}
|> then(fn m ->
# for some reason, it segfaults on linux with large array size
if :os.type() == {:unix, :darwin} do
Map.merge(m, %{"array size 65535" => 65535})
else
m
end
end),
before_scenario: fn i ->
Enum.to_list(1..i) |> Enum.shuffle()
end
Expand Down
53 changes: 53 additions & 0 deletions bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
defmodule AddTwoIntVec do
use Charms
alias Charms.{SIMD, Term, Pointer}

defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do
i_ptr = Pointer.allocate(i32())
Pointer.store(arith.constant(value: Attribute.integer(i32(), 0)), i_ptr)
init = SIMD.new(i32(), 8).(0, 0, 0, 0, 0, 0, 0, 0)

Enum.reduce(l, init, fn x, acc ->
v_ptr = Pointer.allocate(i32())
enif_get_int(env, x, v_ptr)
i = Pointer.load(i32(), i_ptr)
Pointer.store(i + 1, i_ptr)

Pointer.load(i32(), v_ptr)
|> vector.insertelement(acc, i)
end)
|> func.return()
end

defm add(env, a, b, error) :: Term.t() do
v1 = call load_list(env, a) :: SIMD.t(i32(), 8)
v2 = call load_list(env, b) :: SIMD.t(i32(), 8)
v = arith.addi(v1, v2)
start = arith.constant(value: Attribute.integer(i32(), 0))

ret =
enif_make_list8(
env,
enif_make_int(env, vector.extractelement(v, start)),
enif_make_int(env, vector.extractelement(v, start + 1)),
enif_make_int(env, vector.extractelement(v, start + 2)),
enif_make_int(env, vector.extractelement(v, start + 3)),
enif_make_int(env, vector.extractelement(v, start + 4)),
enif_make_int(env, vector.extractelement(v, start + 5)),
enif_make_int(env, vector.extractelement(v, start + 6)),
enif_make_int(env, vector.extractelement(v, start + 7))
)

func.return(ret)
end

defm dummy_load_no_make(env, a, b, error) :: Term.t() do
v1 = call load_list(env, a) :: SIMD.t(i32(), 8)
v2 = call load_list(env, b) :: SIMD.t(i32(), 8)
func.return(a)
end

defm dummy_return(env, a, b, error) :: Term.t() do
func.return(a)
end
end
16 changes: 16 additions & 0 deletions lib/charms/debug.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
defmodule Charms.Debug do
@moduledoc false
alias Beaver.MLIR

def print_ir_pass(op) do
if System.get_env("DEFM_PRINT_IR") == "1" do
MLIR.Transforms.print_ir(op)
else
op
end
end

def step_print?() do
System.get_env("DEFM_PRINT_IR") == "step"
end
end
6 changes: 4 additions & 2 deletions lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ defmodule Charms.Defm do
blk: Beaver.Env.block(),
available_ops: available_ops,
vars: Map.new(),
region: nil
region: nil,
enif_env: nil
}

for {env, d} <- definitions do
Expand All @@ -153,8 +154,9 @@ defmodule Charms.Defm do
"func.func",
Charms.Defm.Pass.CreateAbsentFunc
)
|> Charms.Debug.print_ir_pass()
|> canonicalize
|> MLIR.Pass.Composer.run!(print: System.get_env("DEFM_PRINT_IR") == "1")
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)

MLIR.Context.destroy(ctx)
Expand Down
149 changes: 140 additions & 9 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
defmodule Charms.Defm.Expander do
@moduledoc false
require Logger
use Beaver
alias MLIR.Dialect.{Func, CF}
alias MLIR.Dialect.{Func, CF, SCF}
require Func
# Define the environment we will use for expansion.
# We reset the fields below but we will need to set
Expand Down Expand Up @@ -41,7 +40,8 @@ defmodule Charms.Defm.Expander do
blk: MLIR.Block.create(),
available_ops: available_ops,
vars: Map.new(),
region: nil
region: nil,
enif_env: nil
}

expand(
Expand Down Expand Up @@ -96,6 +96,92 @@ defmodule Charms.Defm.Expander do
create_call(mod, name, args, types, state, env)
end

defp has_implemented_inference(op, ctx) when is_bitstring(op) do
id = MLIR.CAPI.mlirInferTypeOpInterfaceTypeID()

op
|> MLIR.StringRef.create()
|> MLIR.CAPI.mlirOperationImplementsInterfaceStatic(ctx, id)
|> Beaver.Native.to_term()
end

defp expand_std(Enum, :reduce, args, state, env) do
while =
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
[l, init, f] = args
{l, state, env} = expand(l, state, env)
{init, state, env} = expand(init, state, env)
result_t = MLIR.Value.type(init)
state = put_mlir_var(state, :charms_internal_list, l)

{_, state, env} =
quote do
charms_internal_tail_ptr = Charms.Pointer.allocate(Term.t())
Pointer.store(charms_internal_list, charms_internal_tail_ptr)
charms_internal_head_ptr = Charms.Pointer.allocate(Term.t())
end
|> expand(state, env)

# we compile the Enum.reduce/3 to a scf.while in MLIR
SCF.while [init] do
region do
block _(acc >>> result_t) do
state = put_in(state.mlir.blk, Beaver.Env.block())

# getting the BEAM env, assuming it is a regular defm with env as the first argument
state =
if e = state.mlir.enif_env do
put_mlir_var(state, :charms_internal_env, e)
else
raise ArgumentError, "No enif_env found"
end

# the condition of the while loop, consuming the list with enif_get_list_cell
{condition, _state, _env} =
quote do
enif_get_list_cell(
charms_internal_env,
Pointer.load(Term.t(), charms_internal_tail_ptr),
charms_internal_head_ptr,
charms_internal_tail_ptr
) > 0
end
|> expand(state, env)

SCF.condition(condition, acc) >>> []
end
end

# the body of the while loop, compiled from the reducer which is an anonymous function
region do
block _(acc >>> result_t) do
state = put_in(state.mlir.blk, Beaver.Env.block())
{:fn, _, [{:->, _, [[arg_element, arg_acc], body]}]} = f

# inject head and acc before expanding the body
state = put_mlir_var(state, arg_acc, acc)

{head_val, state, env} =
quote(do: Charms.Pointer.load(Charms.Term.t(), charms_internal_head_ptr))
|> expand(state, env)

state = put_mlir_var(state, arg_element, head_val)

# expand the body
{body, _state, _env} = expand(body, state, env)
SCF.yield(List.last(body)) >>> []
end
end
end >>> result_t
end

{while, state, env}
end

defp expand_std(module, fun, _args, _state, _env) do
raise ArgumentError, "Unknown standard function: #{inspect(module)}.#{fun}"
end

# The goal of this function is to traverse all of Elixir special
# forms. The list is actually relatively small and a good reference
# is the Elixir type checker: https://github.com/elixir-lang/elixir/blob/494a018abbc88901747c32032ec9e2c408f40608/lib/elixir/lib/module/types/expr.ex
Expand Down Expand Up @@ -323,6 +409,9 @@ defmodule Charms.Defm.Expander do
{args, state, env} = args |> expand(state, env)
{apply(Beaver.MLIR.Attribute, fun, args), state, env}

[module, fun] == [Enum, :reduce] ->
expand_std(Enum, :reduce, args, state, env)

true ->
raise ArgumentError, "Unknown intrinsic: #{inspect(module)}.#{fun}"
end
Expand All @@ -335,19 +424,36 @@ defmodule Charms.Defm.Expander do
raise ArgumentError,
"Unknown MLIR operation to create: #{op}, did you mean: #{did_you_mean_op(op)}"

{args, state, env} = expand_list(args, state, env)
{args, state, env} = expand(args, state, env)

op =
%Beaver.SSA{
op: op,
arguments: args,
ctx: state.mlir.ctx,
block: state.mlir.blk,
loc: Beaver.MLIR.Location.from_env(env)
loc: Beaver.MLIR.Location.from_env(env),
results: if(has_implemented_inference(op, state.mlir.ctx), do: [:infer], else: [])
}
|> MLIR.Operation.create()

{{MLIR.Operation.results(op), meta, args}, state, env}
{MLIR.Operation.results(op), state, env}
end
end

# Parameterized function call
defp expand(
{{:., _parameterized_meta, [parameterized]}, _meta, args},
state,
env
) do
{args, state, env} = expand(args, state, env)
{parameterized, state, env} = expand(parameterized, state, env)

if is_function(parameterized) do
{parameterized.(args), state, env}
else
raise ArgumentError, "Expected a function, got: #{inspect(parameterized)}"
end
end

Expand Down Expand Up @@ -482,6 +588,7 @@ defmodule Charms.Defm.Expander do
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
{ret_types, state, env} = ret_types |> expand(state, env)
{arg_types, state, env} = arg_types |> expand(state, env)

ft = Type.function(arg_types, ret_types, ctx: Beaver.Env.context())

Func.func _(sym_name: "\"#{name}\"", function_type: ft) do
Expand All @@ -499,6 +606,17 @@ defmodule Charms.Defm.Expander do
Enum.zip(args, arg_values)
|> Enum.reduce(state, fn {k, v}, state -> put_mlir_var(state, k, v) end)

state =
with [head_arg_type | _] <- arg_types,
[head_arg | _] <- args,
{:env, _, nil} <- head_arg,
MLIR.Type.equal?(head_arg_type, Beaver.ENIF.Type.env(ctx: state.mlir.ctx)) do
a = MLIR.Block.get_arg!(Beaver.Env.block(), 0)
put_in(state.mlir.enif_env, a)
else
_ -> state
end

state = put_in(state.mlir.blk, Beaver.Env.block())
expand(body, state, env)
end
Expand Down Expand Up @@ -652,7 +770,16 @@ defmodule Charms.Defm.Expander do
{{dialect, _, _}, op, args} = Macro.decompose_call(call)
op = "#{dialect}.#{op}"
{args, state, env} = expand(args, state, env)
{[return_types], state, env} = expand(return_types, state, env)
{return_types, state, env} = expand(return_types, state, env)

return_types =
case return_types do
[] ->
[:infer]

_ ->
List.flatten(return_types)
end

op =
%Beaver.SSA{
Expand Down Expand Up @@ -815,11 +942,15 @@ defmodule Charms.Defm.Expander do
{ast, state, env}
end

defp put_mlir_var(state, {name, _meta, _ctx} = _ast, val) do
defp put_mlir_var(state, name, val) when is_atom(name) do
update_in(state.mlir.vars, &Map.put(&1, name, val))
end

defp get_mlir_var(state, {name, _meta, _ctx} = _ast) do
defp put_mlir_var(state, {name, _meta, _ctx}, val) do
put_mlir_var(state, name, val)
end

defp get_mlir_var(state, {name, _meta, _ctx}) do
Map.get(state.mlir.vars, name)
end

Expand Down
4 changes: 3 additions & 1 deletion lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ defmodule Charms.JIT do
|> convert_arith_to_llvm()
|> convert_index_to_llvm()
|> convert_func_to_llvm()
|> MLIR.Pass.Composer.append("convert-vector-to-llvm{reassociate-fp-reductions}")
|> MLIR.Pass.Composer.append("finalize-memref-to-llvm")
|> reconcile_unrealized_casts
|> MLIR.Pass.Composer.run!(print: System.get_env("DEFM_PRINT_IR") == "1")
|> Charms.Debug.print_ir_pass()
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.ExecutionEngine.create!(opt_level: 3, object_dump: true)
|> tap(&beaver_raw_jit_register_enif(&1.ref))
end
Expand Down
Loading

0 comments on commit c04abfd

Please sign in to comment.