Skip to content

Commit

Permalink
Use __IR__ in defintrinsic (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackalcooper authored Nov 30, 2024
1 parent 2d2b2ad commit 6752628
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 106 deletions.
63 changes: 14 additions & 49 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ defmodule Charms.Defm.Expander do

# expand the body
{body, _state, _env} = expand(body, state, env)
SCF.yield(List.last(body)) >>> []
SCF.yield(body) >>> []
end
end
end >>> result_t
Expand Down Expand Up @@ -465,46 +465,25 @@ defmodule Charms.Defm.Expander do
Pointer.load(Term.t(), term_ptr)
end
|> expand(state, env)
|> then(&{List.last(elem(&1, 0)), state, env})
end

defp expand_intrinsics(loc, module, intrinsic_impl, args, state, env) do
{args, state, env} = expand(args, state, env)
{params, state} = uniq_mlir_params(args, state)

v =
apply(module, intrinsic_impl, [
params,
args,
%Charms.Intrinsic.Opts{
ctx: state.mlir.ctx,
args: args,
block: state.mlir.blk,
loc: loc,
eval: fn ast ->
expand(
ast,
state,
env
)
end
loc: loc
}
])

case v do
%m{} when m in [MLIR.Value, MLIR.Type, MLIR.Operation] ->
{v, state, env}

f when is_function(f) ->
{f, state, env}

{:__block__, _, list} ->
# do not leak variables created in the macro
{v, _, _} = expand_list(list, state, env)

v
|> List.last()
|> then(&{&1, state, env})

ast = {_, _, _} ->
# do not leak variables created in the macro
{v, _state, _env} = expand(ast, state, env)
Expand Down Expand Up @@ -612,6 +591,11 @@ defmodule Charms.Defm.Expander do

defp expand({:__block__, _, list}, state, env) do
expand_list(list, state, env)
|> then(fn
{[], _, e} -> raise_compile_error(e, "empty block cannot be expanded")
{l, s, e} when is_list(l) -> {List.last(l), s, e}
{l, _, e} -> raise_compile_error(e, "Expected a list, got: #{inspect(l)}")
end)
end

## __aliases__
Expand Down Expand Up @@ -709,6 +693,11 @@ defmodule Charms.Defm.Expander do

try do
intrinsic_impl = Charms.Kernel.__intrinsics__(fun, length(args))

unless intrinsic_impl do
raise_compile_error(env, "intrinsic implementation not found for #{fun}/#{length(args)}")
end

expand_intrinsics(loc, Charms.Kernel, intrinsic_impl, args, state, env)
|> then(fn {v, _, _} ->
if is_list(v) do
Expand Down Expand Up @@ -795,22 +784,6 @@ defmodule Charms.Defm.Expander do
end
end

# Parameterized function call
defp expand(
{{:., _parameterized_meta, [parameterized]}, _meta, args},
state,
env
) do
{args, state, env} = expand(args, state, env)
{parameterized, state, env} = expand(parameterized, state, env)

if is_function(parameterized) do
{parameterized.(args), state, env}
else
raise_compile_error(env, "Expected a function, got: #{inspect(parameterized)}")
end
end

## Imported or local call

defp expand({fun, meta, args}, state, env) when is_atom(fun) and is_list(args) do
Expand Down Expand Up @@ -920,7 +893,7 @@ defmodule Charms.Defm.Expander do
mlir ctx: state.mlir.ctx, block: state.mlir.blk do
{ret, _, _} = expand(clause_body, state, env)

case ret |> List.wrap() |> List.last() do
case ret do
%MLIR.Operation{} ->
SCF.yield() >>> []
[]
Expand Down Expand Up @@ -1420,14 +1393,6 @@ defmodule Charms.Defm.Expander do
uniq_mlir_var() |> then(&{&1, put_mlir_var(state, &1, val)})
end

defp uniq_mlir_params(args, state) when is_list(args) do
for param <- args, reduce: {[], state} do
{params, %{mlir: _} = state} ->
{param, %{mlir: _} = state} = uniq_mlir_var(param, state)
{params ++ [param], state}
end
end

@doc """
Decomposes a call into the call part and return type.
"""
Expand Down
3 changes: 2 additions & 1 deletion lib/charms/env.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ defmodule Charms.Env do
use Charms.Intrinsic
alias Charms.Intrinsic.Opts

defintrinsic t(), %Opts{ctx: ctx} do
defintrinsic t() do
%Opts{ctx: ctx} = __IR__
Beaver.ENIF.Type.env(ctx: ctx)
end
end
33 changes: 21 additions & 12 deletions lib/charms/intrinsic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ defmodule Charms.Intrinsic do
@moduledoc """
Options for intrinsic functions.
"""
defstruct [:ctx, :args, :block, :loc, :eval]
defstruct [:ctx, :block, :loc]
end

@moduledoc """
Behaviour to define intrinsic functions.
"""
alias Beaver
@type opt :: {:ctx, MLIR.Context.t()} | {:block, MLIR.Block.t() | {:loc, MLIR.Location.t()}}
@type opts :: [opt | {atom(), term()}]
@type ir_return :: MLIR.Value.t() | MLIR.Operation.t()
@type intrinsic_return :: ir_return() | (any() -> ir_return())
Module.register_attribute(__MODULE__, :defintrinsic, accumulate: true)
Expand All @@ -25,12 +23,6 @@ defmodule Charms.Intrinsic do
end
end

defmacro defintrinsic(call, do: body) do
quote do
defintrinsic(unquote(call), %Charms.Intrinsic.Opts{}, do: unquote(body))
end
end

defp unwrap_unquote(name) do
case name do
{:unquote, _, [name]} ->
Expand All @@ -41,10 +33,15 @@ defmodule Charms.Intrinsic do
end
end

defp recompose_when_clauses(name, args, opts) do
defp recompose_when_clauses(name, args) do
intrinsic_name_ast =
{:unquote, [], [quote(do: :"__defintrinsic_#{unquote(unwrap_unquote(name))}__")]}

opts =
quote do
%Charms.Intrinsic.Opts{} = var!(charms_intrinsic_internal_opts)
end

case opts do
{:when, when_meta, [opts | clauses]} ->
{:when, when_meta,
Expand Down Expand Up @@ -81,9 +78,9 @@ defmodule Charms.Intrinsic do
@doc """
To implement an intrinsic function
"""
defmacro defintrinsic(call, opts, do: body) do
defmacro defintrinsic(call, do: body) do
{name, _meta, args} = call
call = recompose_when_clauses(name, args, opts)
call = recompose_when_clauses(name, args)
placeholder_args = normalize_arg_names(args)

# can't get the arity from length(args), because it might be an unquote_splicing, whose length is 1
Expand All @@ -96,10 +93,22 @@ defmodule Charms.Intrinsic do
end
end

body =
Macro.postwalk(body, fn
{:__IR__, _, args} when args == [] or args == nil ->
quote do
var!(charms_intrinsic_internal_opts)
end

ast ->
ast
end)

quote do
unquote(placeholder)
@doc false
def unquote(call) do
%Charms.Intrinsic.Opts{ctx: ctx} = var!(charms_intrinsic_internal_opts)
unquote(body)
end

Expand Down
11 changes: 6 additions & 5 deletions lib/charms/kernel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ defmodule Charms.Kernel do
end

for name <- @binary_ops ++ @binary_macro_ops do
defintrinsic unquote(name)(_left, _right),
opts = %Opts{args: [left, right], ctx: ctx, block: block, loc: loc} do
defintrinsic unquote(name)(left, right) do
opts = %Opts{ctx: ctx, block: block, loc: loc} = __IR__

{operands, type} =
case {left, right} do
{%MLIR.Value{} = v, i} when is_integer(i) ->
Expand All @@ -85,14 +86,14 @@ defmodule Charms.Kernel do
end
end

defintrinsic !_value, %Opts{args: [v]} do
t = MLIR.Value.type(v)
defintrinsic !value do
t = MLIR.Value.type(value)

unless MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
raise ArgumentError, "Not an integer type to negate, unsupported type: #{to_string(t)}"
end

quote bind_quoted: [v: v, t: t] do
quote bind_quoted: [v: value, t: t] do
one = const 1 :: t
value arith.xori(v, one) :: t
end
Expand Down
44 changes: 22 additions & 22 deletions lib/charms/pointer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,63 @@ defmodule Charms.Pointer do
Allocates a single element of the given `elem_type`, returning a pointer to it.
"""
defintrinsic allocate(elem_type) do
quote do
Charms.Pointer.allocate(unquote(elem_type), 1)
quote bind_quoted: [elem_type: elem_type] do
Charms.Pointer.allocate(elem_type, 1)
end
end

@doc """
Allocates an array of `size` elements of the given `elem_type`, returning a pointer to it.
"""
defintrinsic allocate(elem_type, size), %Opts{ctx: ctx, args: [_elem_type, size_v]} do
defintrinsic allocate(elem_type, size) do
%Opts{ctx: ctx} = __IR__

cast =
case size_v do
case size do
i when is_integer(i) ->
quote do
const unquote(size_v) :: i64()
quote bind_quoted: [size: i] do
const size :: i64()
end

%MLIR.Value{} ->
if MLIR.equal?(MLIR.Value.type(size_v), Type.i64(ctx: ctx)) do
if MLIR.equal?(MLIR.Value.type(size), Type.i64(ctx: ctx)) do
size
else
quote do
value arith.extsi(unquote(size)) :: i64()
quote bind_quoted: [size: size] do
value arith.extsi(size) :: i64()
end
end
end

quote do
size = unquote(cast)
value llvm.alloca(size, elem_type: unquote(elem_type)) :: Pointer.t()
quote bind_quoted: [elem_type: elem_type, size: cast] do
value llvm.alloca(size, elem_type: elem_type) :: Pointer.t()
end
end

@doc """
Loads a value of `type` from the given pointer `ptr`.
"""
defintrinsic load(type, ptr) do
quote do
value llvm.load(unquote(ptr)) :: unquote(type)
quote bind_quoted: [type: type, ptr: ptr] do
value llvm.load(ptr) :: type
end
end

@doc """
Stores a value `val` at the given pointer `ptr`.
"""
defintrinsic store(val, ptr) do
quote do
llvm.store(unquote(val), unquote(ptr))
quote bind_quoted: [val: val, ptr: ptr] do
llvm.store(val, ptr)
end
end

@doc """
Gets the element pointer of `elem_type` for the given base pointer `ptr` and index `n`.
"""
defintrinsic element_ptr(_elem_type, _ptr, _n), %Opts{
ctx: ctx,
block: block,
args: [elem_type, ptr, n]
} do
defintrinsic element_ptr(elem_type, ptr, n) do
%Opts{ctx: ctx, block: block} = __IR__

mlir ctx: ctx, block: block do
LLVM.getelementptr(ptr, n,
elem_type: elem_type,
Expand All @@ -81,7 +80,8 @@ defmodule Charms.Pointer do
@doc """
Return the pointer type
"""
defintrinsic t(), %Opts{ctx: ctx} do
defintrinsic t() do
%Opts{ctx: ctx} = __IR__
Beaver.Deferred.create(~t{!llvm.ptr}, ctx)
end
end
29 changes: 19 additions & 10 deletions lib/charms/prelude.ex
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,29 @@ defmodule Charms.Prelude do
v
end

defintrinsic result_at(_entity, _index),
%Opts{args: [%MLIR.Operation{} = op, i]} when is_integer(i) do
defintrinsic result_at(%MLIR.Operation{} = op, index) do
num_results = MLIR.CAPI.mlirOperationGetNumResults(op)

if i < num_results do
MLIR.CAPI.mlirOperationGetResult(op, i)
if index < num_results do
MLIR.CAPI.mlirOperationGetResult(op, index)
else
raise ArgumentError,
"Index #{i} is out of bounds for operation results, num results: #{num_results}"
"Index #{index} is out of bounds for operation results, num results: #{num_results}"
end
end

defintrinsic type_of(_value), %Opts{args: [v]} do
MLIR.Value.type(v)
@doc """
Get the MLIR type of the given value.
"""
defintrinsic type_of(value) do
MLIR.Value.type(value)
end

@doc """
Dump the MLIR entity at compile time with `IO.puts/1`
"""
defintrinsic dump(entity) do
entity |> tap(&IO.puts(MLIR.to_string(&1)))
end

signature_ctx = MLIR.Context.create()
Expand All @@ -49,10 +58,10 @@ defmodule Charms.Prelude do
{arg_types, _} = Beaver.ENIF.signature(signature_ctx, name)
args = Macro.generate_arguments(length(arg_types), __MODULE__)

defintrinsic unquote(name)(unquote_splicing(args)),
opts = %Opts{args: args, ctx: ctx, block: block, loc: loc} do
defintrinsic unquote(name)(unquote_splicing(args)) do
opts = %Opts{ctx: ctx, block: block, loc: loc} = __IR__
{arg_types, ret_types} = Beaver.ENIF.signature(ctx, unquote(name))
args = args |> Enum.zip(arg_types) |> Enum.map(&wrap_arg(&1, opts))
args = [unquote_splicing(args)] |> Enum.zip(arg_types) |> Enum.map(&wrap_arg(&1, opts))

mlir ctx: ctx, block: block do
Func.call(args, callee: Attribute.flat_symbol_ref("#{unquote(name)}"), loc: loc) >>>
Expand Down
Loading

0 comments on commit 6752628

Please sign in to comment.