Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile binary literals in expander #12

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 28 additions & 38 deletions lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ defmodule Charms.Defm do
- use `beaver`'s DSL to define intrinsics which can be called in the function body of a `defm`
- use `defm` to define functions that can be JIT-compiled
"""
require Beaver.Env
use Beaver
alias MLIR.Dialect.Func
require Func
Expand Down Expand Up @@ -130,47 +131,36 @@ defmodule Charms.Defm do
def compile_definitions(definitions) do
import MLIR.Transforms
ctx = MLIR.Context.create()
available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx))

m =
mlir ctx: ctx do
module do
mlir = %{
ctx: ctx,
blk: Beaver.Env.block(),
available_ops: available_ops,
vars: Map.new(),
region: nil,
enif_env: nil
}

for {env, d} <- definitions do
{call, ret_types, body} = d

ast =
quote do
def(unquote(call) :: unquote(ret_types), unquote(body))
end

Charms.Defm.Expander.expand_with_mlir(
ast,
mlir,
env
)
end
m = MLIR.Module.create(ctx, "")

mlir ctx: ctx, block: MLIR.Module.body(m) do
mlir = %Charms.Defm.Expander{
ctx: ctx,
blk: Beaver.Env.block(),
available_ops: MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx)),
vars: Map.new(),
region: nil,
enif_env: nil,
mod: m
}

for {env, d} <- definitions do
{call, ret_types, body} = d

quote do
def(unquote(call) :: unquote(ret_types), unquote(body))
end
|> Charms.Defm.Expander.expand_with(env, mlir)
end
|> MLIR.Pass.Composer.nested(
"func.func",
Charms.Defm.Pass.CreateAbsentFunc
)
|> Charms.Debug.print_ir_pass()
|> canonicalize
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)

MLIR.Context.destroy(ctx)
end

m
|> MLIR.Pass.Composer.nested("func.func", Charms.Defm.Pass.CreateAbsentFunc)
|> Charms.Debug.print_ir_pass()
|> canonicalize
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)
|> tap(fn _ -> MLIR.Context.destroy(ctx) end)
end

@doc false
Expand Down
61 changes: 51 additions & 10 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
defmodule Charms.Defm.Expander do
@moduledoc false
alias Beaver.MLIR.Attribute
use Beaver
alias MLIR.Dialect.{Func, CF, SCF}
alias MLIR.Dialect.{Func, CF, SCF, MemRef, Index}
require Func
# Define the environment we will use for expansion.
# We reset the fields below but we will need to set
# them accordingly later on.

defstruct ctx: nil,
mod: nil,
blk: nil,
available_ops: MapSet.new(),
vars: Map.new(),
region: nil,
enif_env: nil

@env %{
Macro.Env.prune_compile_info(__ENV__)
| line: 0,
Expand Down Expand Up @@ -34,9 +44,8 @@ defmodule Charms.Defm.Expander do
ctx = MLIR.Context.create()
available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx))

mlir = %{
mlir = %__MODULE__{
ctx: ctx,
mod: nil,
blk: MLIR.Block.create(),
available_ops: available_ops,
vars: Map.new(),
Expand All @@ -51,7 +60,7 @@ defmodule Charms.Defm.Expander do
)
end

def expand_with_mlir(ast, %{ctx: ctx} = mlir, env) do
def expand_with(ast, env, mlir = %__MODULE__{ctx: ctx}) do
available_ops = MapSet.new(MLIR.Dialect.Registry.ops(:all, ctx: ctx))
mlir = mlir |> Map.put(:available_ops, available_ops)

Expand Down Expand Up @@ -178,8 +187,19 @@ defmodule Charms.Defm.Expander do
{while, state, env}
end

defp expand_std(module, fun, _args, _state, _env) do
raise ArgumentError, "Unknown standard function: #{inspect(module)}.#{fun}"
defp expand_std(String, :length, args, state, env) do
{string, state, env} = expand(args, state, env)

mlir ctx: state.mlir.ctx, block: state.mlir.blk do
zero = Index.constant(value: Attribute.index(0)) >>> Type.index()
len = MemRef.dim(string, zero) >>> :infer
end

{len, state, env}
end

defp expand_std(_module, _fun, _args, _state, _env) do
:not_implemented
end

# The goal of this function is to traverse all of Elixir special
Expand Down Expand Up @@ -400,17 +420,17 @@ defmodule Charms.Defm.Expander do

cond do
function_exported?(module, :handle_intrinsic, 3) ->
{args, state, env} = args |> expand(state, env)
{args, state, env} = expand(args, state, env)

{module.handle_intrinsic(fun, args, ctx: state.mlir.ctx, block: state.mlir.blk),
state, env}

module == Beaver.MLIR.Attribute ->
{args, state, env} = args |> expand(state, env)
{args, state, env} = expand(args, state, env)
{apply(Beaver.MLIR.Attribute, fun, args), state, env}

[module, fun] == [Enum, :reduce] ->
expand_std(Enum, :reduce, args, state, env)
(res = expand_std(module, fun, args, state, env)) != :not_implemented ->
res

true ->
raise ArgumentError, "Unknown intrinsic: #{inspect(module)}.#{fun}"
Expand Down Expand Up @@ -512,6 +532,27 @@ defmodule Charms.Defm.Expander do

## Fallback

defp expand(ast, state, env) when is_binary(ast) do
s_table = state.mlir.mod |> MLIR.Operation.from_module() |> MLIR.CAPI.mlirSymbolTableCreate()
sym_name = "__const__" <> :crypto.hash(:sha256, ast)
found = MLIR.CAPI.mlirSymbolTableLookup(s_table, MLIR.StringRef.create(sym_name))

mlir ctx: state.mlir.ctx, block: MLIR.Module.body(state.mlir.mod) do
if MLIR.is_null(found) do
MemRef.global(ast, sym_name: Attribute.string(sym_name)) >>> :infer
else
found
end
|> then(
&mlir block: state.mlir.blk do
name = Attribute.flat_symbol_ref(Attribute.unwrap(&1[:sym_name]))
MemRef.get_global(name: name) >>> Attribute.unwrap(&1[:type])
end
)
end
|> then(&{&1, state, env})
end

defp expand(ast, state, env) do
{get_mlir_var(state, ast) || ast, state, env}
end
Expand Down
4 changes: 4 additions & 0 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ defmodule Charms.JIT do

def init(module, opts \\ [])

def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
init(module, opts)
end

def init(module, opts) when is_atom(module) do
name = opts[:name] || module
opts = Keyword.put_new(opts, :name, name)
Expand Down
24 changes: 24 additions & 0 deletions test/string_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
defmodule StringTest do
use ExUnit.Case

test "create str" do
defmodule SomeString do
use Charms
alias Charms.{Pointer, Term}

defm get(env) :: Term.t() do
str = "this is a string"
str = "this is a string"
term_ptr = Pointer.allocate(Term.t())
d_ptr = enif_make_new_binary(env, String.length(str), term_ptr)
m = ptr_to_memref(d_ptr)
memref.copy(str, m)
t = Pointer.load(Term.t(), term_ptr)
func.return(t)
end
end
|> Charms.JIT.init()

assert SomeString.get() == "this is a string"
end
end
Loading