Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Charms.Constant.from_literal/5 #56

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions lib/charms/constant.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
defmodule Charms.Constant do
@moduledoc false
use Beaver
alias Beaver.MLIR.Dialect.{Arith, Index}

def from_literal(literal, %MLIR.Value{} = v, ctx, blk, loc) do
t = MLIR.CAPI.mlirValueGetType(v)
from_literal(literal, t, ctx, blk, loc)
end

def from_literal(literal, %MLIR.Type{} = t, ctx, blk, loc) do
mlir ctx: ctx, blk: blk do
cond do
MLIR.Type.integer?(t) ->
Arith.constant(value: Attribute.integer(t, literal), loc: loc) >>> t

MLIR.Type.float?(t) ->
Arith.constant(value: Attribute.float(t, literal), loc: loc) >>> t

MLIR.Type.index?(t) ->
Index.constant(value: Attribute.index(literal), loc: loc) >>> t

true ->
loc = Beaver.Deferred.create(loc, ctx)
raise CompileError, Charms.Diagnostic.meta_from_loc(loc) ++ [description: "Not a supported type for constant, #{to_string(t)}"]
end
end
end
end
27 changes: 15 additions & 12 deletions lib/charms/defm/definition.ex
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,21 @@ defmodule Charms.Defm.Definition do
"""
def compile(definitions) when is_list(definitions) do
ctx = MLIR.Context.create()
{res, msg} = MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

{res, msg} =
MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, do_compile(ctx, definitions)}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

case {res, msg} do
{{:ok, {mlir, mods}}, nil} ->
MLIR.Context.destroy(ctx)
Expand Down
12 changes: 1 addition & 11 deletions lib/charms/defm/expander.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1210,17 +1210,7 @@ defmodule Charms.Defm.Expander do
value =
mlir ctx: state.mlir.ctx, blk: state.mlir.blk do
loc = MLIR.Location.from_env(env)

cond do
MLIR.CAPI.mlirTypeIsAInteger(type) |> Beaver.Native.to_term() ->
Arith.constant(value: Attribute.integer(type, value), loc: loc) >>> type

MLIR.CAPI.mlirTypeIsAFloat(type) |> Beaver.Native.to_term() ->
Arith.constant(value: Attribute.float(type, value), loc: loc) >>> type

true ->
raise_compile_error(env, "Unsupported type for const macro: #{to_string(type)}")
end
Charms.Constant.from_literal(value, type, state.mlir.ctx, state.mlir.blk, loc)
end

{value, state, env}
Expand Down
16 changes: 8 additions & 8 deletions lib/charms/diagnostic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ defmodule Charms.Diagnostic do
@moduledoc false
@doc false
alias Beaver.MLIR

def meta_from_loc(%MLIR.Location{} = loc) do
c = Regex.named_captures(~r/(?<file>.+):(?<line>\d+):(?<column>\d+)/, MLIR.to_string(loc))
[file: c["file"], line: c["line"] || 0]
end
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

def compile_error_message(%Beaver.MLIR.Diagnostic{} = d) do
loc = to_string(MLIR.location(d))
txt = to_string(d)

case txt do
"" ->
{:error, "No diagnostic message"}

note ->
c =
Regex.named_captures(
~r/(?<file>.+):(?<line>\d+):(?<column>\d+)/,
loc
)

{:ok, [file: c["file"], line: c["line"] || 0, description: note]}
{:ok, meta_from_loc(MLIR.location(d)) ++ [description: note]}
end
end

Expand Down
39 changes: 24 additions & 15 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ defmodule Charms.JIT do

defp do_init(modules) when is_list(modules) do
ctx = MLIR.Context.create()

modules
|> Enum.map(fn
m when is_atom(m) ->
Expand All @@ -81,21 +82,29 @@ defmodule Charms.JIT do
raise ArgumentError, "Unexpected module type: #{inspect(other)}"
end)
|> then(fn op ->
{res, _} = MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, op |> merge_modules() |> jit_of_mod()}
rescue
err ->
{:error, err}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)
case res do
{:ok, jit} -> jit
{:error, err} -> raise err
{res, msg} =
MLIR.Context.with_diagnostics(
ctx,
fn ->
try do
{:ok, op |> merge_modules() |> jit_of_mod()}
rescue
err ->
{:error, err, __STACKTRACE__}
end
end,
fn d, _acc -> Charms.Diagnostic.compile_error_message(d) end
)

case {res, msg} do
{{:ok, jit}, nil} ->
jit

{{:error, _, st}, {:ok, d_msg}} ->
reraise CompileError, d_msg, st

{{:error, err, st}, _} ->
reraise err, st
end
end)
|> then(
Expand Down
23 changes: 7 additions & 16 deletions lib/charms/kernel.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@ defmodule Charms.Kernel do
use Charms.Intrinsic
alias Charms.Intrinsic.Opts
alias Beaver.MLIR.Dialect.Arith
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*]
@binary_ops [:!=, :-, :+, :<, :>, :<=, :>=, :==, :*, :/]
@unary_ops [:!]
@binary_macro_ops [:&&, :||]

defp constant_of_same_type(i, v, %Opts{ctx: ctx, blk: blk, loc: loc}) do
mlir ctx: ctx, blk: blk do
t = MLIR.CAPI.mlirValueGetType(v)

if MLIR.CAPI.mlirTypeIsAInteger(t) |> Beaver.Native.to_term() do
Arith.constant(value: Attribute.integer(t, i), loc: loc) >>> t
else
raise ArgumentError, "Not an integer type for constant, #{to_string(t)}"
end
end
end

@compare_ops [:!=, :==, :>, :>=, :<, :<=]
defp i_predicate(:!=), do: :ne
defp i_predicate(:==), do: :eq
Expand Down Expand Up @@ -51,6 +39,9 @@ defmodule Charms.Kernel do
:* ->
Arith.muli(operands, loc: loc) >>> type

:/ ->
Arith.divsi(operands, loc: loc) >>> type

_ ->
raise ArgumentError, "Unsupported operator: #{inspect(op)}"
end
Expand All @@ -59,15 +50,15 @@ defmodule Charms.Kernel do

for name <- @binary_ops ++ @binary_macro_ops do
defintrinsic unquote(name)(left, right) do
opts = %Opts{ctx: ctx, blk: blk, loc: loc} = __IR__
%Opts{ctx: ctx, blk: blk, loc: loc} = __IR__

{operands, type} =
case {left, right} do
{%MLIR.Value{} = v, i} when is_integer(i) ->
[v, constant_of_same_type(i, v, opts)]
[v, Charms.Constant.from_literal(i, v, ctx, blk, loc)]

{i, %MLIR.Value{} = v} when is_integer(i) ->
[constant_of_same_type(i, v, opts), v]
[Charms.Constant.from_literal(i, v, ctx, blk, loc), v]

{%MLIR.Value{}, %MLIR.Value{}} ->
if not MLIR.equal?(MLIR.Value.type(left), MLIR.Value.type(right)) do
Expand Down
2 changes: 1 addition & 1 deletion test/const_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ defmodule ConstTest do
end

assert_raise CompileError,
~r"test/const_test.exs:13: Unsupported type for const macro: tensor<\*xf64>",
~r"test/const_test.exs:13: Not a supported type for constant, tensor<\*xf64>",
f
end
end
2 changes: 2 additions & 0 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ defmodule AddTwoInt do
a = Pointer.load(i32(), ptr_a)
b = Pointer.load(i32(), ptr_b)
sum = value llvm.add(a, b) :: i32()
sum = sum / 1
sum = sum + 1 - 1
term = enif_make_int(env, sum)
func.return(term)
else
Expand Down
Loading