diff --git a/lib/charms/defm.ex b/lib/charms/defm.ex index 30f8459..1ed66a9 100644 --- a/lib/charms/defm.ex +++ b/lib/charms/defm.ex @@ -6,6 +6,7 @@ defmodule Charms.Defm do - use `beaver`'s DSL to define intrinsics which can be called in the function body of a `defm` - use `defm` to define functions that can be JIT-compiled """ + require Beaver.Env use Beaver alias MLIR.Dialect.Func require Func @@ -130,47 +131,36 @@ defmodule Charms.Defm do def compile_definitions(definitions) do import MLIR.Transforms ctx = MLIR.Context.create() - available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx)) - - m = - mlir ctx: ctx do - module do - mlir = %{ - ctx: ctx, - blk: Beaver.Env.block(), - available_ops: available_ops, - vars: Map.new(), - region: nil, - enif_env: nil - } - - for {env, d} <- definitions do - {call, ret_types, body} = d - - ast = - quote do - def(unquote(call) :: unquote(ret_types), unquote(body)) - end - - Charms.Defm.Expander.expand_with_mlir( - ast, - mlir, - env - ) - end + m = MLIR.Module.create(ctx, "") + + mlir ctx: ctx, block: MLIR.Module.body(m) do + mlir = %Charms.Defm.Expander{ + ctx: ctx, + blk: Beaver.Env.block(), + available_ops: MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx)), + vars: Map.new(), + region: nil, + enif_env: nil, + mod: m + } + + for {env, d} <- definitions do + {call, ret_types, body} = d + + quote do + def(unquote(call) :: unquote(ret_types), unquote(body)) end + |> Charms.Defm.Expander.expand_with(env, mlir) end - |> MLIR.Pass.Composer.nested( - "func.func", - Charms.Defm.Pass.CreateAbsentFunc - ) - |> Charms.Debug.print_ir_pass() - |> canonicalize - |> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?()) - |> MLIR.to_string(bytecode: true) - - MLIR.Context.destroy(ctx) + end + m + |> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc) + |> Charms.Debug.print_ir_pass() + |> canonicalize + |> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?()) + |> MLIR.to_string(bytecode: true) + |> tap(fn _ -> MLIR.Context.destroy(ctx) end) end @doc false diff --git a/lib/charms/defm/expander.ex b/lib/charms/defm/expander.ex index b74cd0e..1fcae07 100644 --- a/lib/charms/defm/expander.ex +++ b/lib/charms/defm/expander.ex @@ -1,11 +1,21 @@ defmodule Charms.Defm.Expander do @moduledoc false + alias Beaver.MLIR.Attribute use Beaver - alias MLIR.Dialect.{Func, CF, SCF} + alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index} require Func # Define the environment we will use for expansion. # We reset the fields below but we will need to set # them accordingly later on. + + defstruct ctx: nil, + mod: nil, + blk: nil, + available_ops: MapSet.new(), + vars: Map.new(), + region: nil, + enif_env: nil + @env %{ Macro.Env.prune_compile_info(__ENV__) | line: 0, @@ -34,9 +44,8 @@ defmodule Charms.Defm.Expander do ctx = MLIR.Context.create() available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx)) - mlir = %{ + mlir = %__MODULE__{ ctx: ctx, - mod: nil, blk: MLIR.Block.create(), available_ops: available_ops, vars: Map.new(), @@ -51,7 +60,7 @@ defmodule Charms.Defm.Expander do ) end - def expand_with_mlir(ast, %{ctx: ctx} = mlir, env) do + def expand_with(ast, env, mlir = %__MODULE__{ctx: ctx}) do available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx)) mlir = mlir |> Map.put(:available_ops, available_ops) @@ -178,8 +187,19 @@ defmodule Charms.Defm.Expander do {while, state, env} end - defp expand_std(module, fun, _args, _state, _env) do - raise ArgumentError, "Unknown standard function: #{inspect(module)}.#{fun}" + defp expand_std(String, :length, args, state, env) do + {string, state, env} = expand(args, state, env) + + mlir ctx: state.mlir.ctx, block: state.mlir.blk do + zero = Index.constant(value: Attribute.index(0)) >>> Type.index() + len = MemRef.dim(string, zero) >>> :infer + end + + {len, state, env} + end + + defp expand_std(_module, _fun, _args, _state, _env) do + :not_implemented end # The goal of this function is to traverse all of Elixir special @@ -400,17 +420,17 @@ defmodule Charms.Defm.Expander do cond do function_exported?(module, :handle_intrinsic, 3) -> - {args, state, env} = args |> expand(state, env) + {args, state, env} = expand(args, state, env) {module.handle_intrinsic(fun, args, ctx: state.mlir.ctx, block: state.mlir.blk), state, env} module == Beaver.MLIR.Attribute -> - {args, state, env} = args |> expand(state, env) + {args, state, env} = expand(args, state, env) {apply(Beaver.MLIR.Attribute, fun, args), state, env} - [module, fun] == [Enum, :reduce] -> - expand_std(Enum, :reduce, args, state, env) + (res = expand_std(module, fun, args, state, env)) != :not_implemented -> + res true -> raise ArgumentError, "Unknown intrinsic: #{inspect(module)}.#{fun}" @@ -512,6 +532,27 @@ defmodule Charms.Defm.Expander do ## Fallback + defp expand(ast, state, env) when is_binary(ast) do + s_table = state.mlir.mod |> MLIR.Operation.from_module() |> MLIR.CAPI.mlirSymbolTableCreate() + sym_name = "__const__" <> :crypto.hash(:sha256, ast) + found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name)) + + mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do + if MLIR.is_null(found) do + MemRef.global(ast, sym_name: Attribute.string(sym_name)) >>> :infer + else + found + end + |> then( + &mlir block: state.mlir.blk do + name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name])) + MemRef.get_global(name: name) >>> Attribute.unwrap(&1[:type]) + end + ) + end + |> then(&{&1, state, env}) + end + defp expand(ast, state, env) do {get_mlir_var(state, ast) || ast, state, env} end diff --git a/lib/charms/jit.ex b/lib/charms/jit.ex index f907206..6ff748e 100644 --- a/lib/charms/jit.ex +++ b/lib/charms/jit.ex @@ -65,6 +65,10 @@ defmodule Charms.JIT do def init(module, opts \\ []) + def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do + init(module, opts) + end + def init(module, opts) when is_atom(module) do name = opts[:name] || module opts = Keyword.put_new(opts, :name, name) diff --git a/test/string_test.exs b/test/string_test.exs new file mode 100644 index 0000000..afa802e --- /dev/null +++ b/test/string_test.exs @@ -0,0 +1,24 @@ +defmodule StringTest do + use ExUnit.Case + + test "create str" do + defmodule SomeString do + use Charms + alias Charms.{Pointer, Term} + + defm get(env) :: Term.t() do + str = "this is a string" + str = "this is a string" + term_ptr = Pointer.allocate(Term.t()) + d_ptr = enif_make_new_binary(env, String.length(str), term_ptr) + m = ptr_to_memref(d_ptr) + memref.copy(str, m) + t = Pointer.load(Term.t(), term_ptr) + func.return(t) + end + end + |> Charms.JIT.init() + + assert SomeString.get() == "this is a string" + end +end