Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect referenced modules and init them together #34

Merged
merged 6 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ end

- run the benchmarks of sorting algorithms
```sh
mix run bench/sort.exs
mix run bench/sort_benchmark.exs
mix run bench/list_add_benchmark.exs
```
2 changes: 1 addition & 1 deletion bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFQuickSort do
@moduledoc false
use Charms, init: false
use Charms
alias Charms.{Pointer, Term, Env}

defm swap(a :: Pointer.t(), b :: Pointer.t()) do
Expand Down
2 changes: 1 addition & 1 deletion bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFTimSort do
@moduledoc false
use Charms, init: false
use Charms
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
alias Charms.{Pointer, Term, Env}

defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do
Expand Down
5 changes: 0 additions & 5 deletions bench/list_add_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
mod = AddTwoIntVec
Charms.JIT.init(mod)

a = b = Enum.to_list(1..10)
AddTwoIntVec.add(a, b, :err_msg)
AddTwoIntVec.dummy_load_no_make(a, b, :err_msg)
Expand All @@ -26,5 +23,3 @@ Benchee.run(
{a, b}
end
)

Charms.JIT.destroy(mod)
3 changes: 0 additions & 3 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])

arr = Enum.to_list(1..10000) |> Enum.shuffle()
ENIFQuickSort.sort(arr, :arg_err)
ENIFMergeSort.sort(arr, :arg_err)
Expand Down
2 changes: 1 addition & 1 deletion bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defmodule AddTwoIntVec do
use Charms, init: false
use Charms
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
alias Charms.{SIMD, Term, Pointer}

defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do
Expand Down
13 changes: 12 additions & 1 deletion lib/charms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,21 @@ defmodule Charms do

defmacro __before_compile__(_env) do
quote do
@ir @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
{ir, referenced_modules} = @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
@ir ir
@referenced_modules referenced_modules

@doc false
def __ir__ do
@ir
end

@doc false
def referenced_modules do
@referenced_modules
end

defoverridable referenced_modules: 0
end
end

Expand Down
34 changes: 33 additions & 1 deletion lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ defmodule Charms.Defm do
:ok
end

defp referenced_modules(module) do
Beaver.Walker.postwalk(module, MapSet.new(), fn
%MLIR.Operation{} = op, acc ->
with "func.call" <- MLIR.Operation.name(op),
callee when not is_nil(callee) <- Beaver.Walker.attributes(op)["callee"] do
case callee |> to_string do
"@Elixir." <> _ = name ->
acc |> MapSet.put(extract_mangled_mod(name))

_ ->
acc
end
|> then(&{op, &1})
else
_ ->
{op, acc}
end

ir, acc ->
{ir, acc}
end)
|> then(fn {_, acc} -> MapSet.to_list(acc) end)
end

@doc false
def compile_definitions(definitions) do
import MLIR.Transforms
Expand Down Expand Up @@ -186,12 +210,20 @@ defmodule Charms.Defm do
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> canonicalize
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)
|> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)})
|> tap(fn _ -> MLIR.Context.destroy(ctx) end)
end

@doc false
def mangling(mod, func) do
Module.concat(mod, func)
end

defp extract_mangled_mod("@" <> name) do
name
|> String.split(".")
|> then(&Enum.take(&1, length(&1) - 1))
|> Enum.join(".")
|> String.to_atom()
end
end
52 changes: 31 additions & 21 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -90,36 +90,46 @@ defmodule Charms.JIT do
|> then(&{:ok, &1})
end

def init(module, opts \\ [])
defp collect_modules(module, acc \\ [])

def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
init(module, opts)
end
defp collect_modules(module, acc) when is_atom(module) do
if module in acc do
acc
else
acc = [module | acc]

def init(module, opts) when is_atom(module) do
name = opts[:name] || module
opts = Keyword.put_new(opts, :name, name)
init([module], opts)
module.referenced_modules()
|> Enum.reduce(acc, fn m, acc ->
collect_modules(m, acc)
end)
end
end

def init(modules, opts) do
modules = modules |> List.wrap()
defp collect_modules(module, acc), do: [module | acc]

case {opts[:name], modules} do
{name, [_]} when not is_nil(name) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
def init(module, opts \\ [])

{nil, modules} when modules != [] ->
[key | tail] = modules
{:ok, jit} = __MODULE__.LockedCache.run(key, fn -> do_init(modules) end)
def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
init(module, opts)
end

for module <- tail,
do:
__MODULE__.LockedCache.run(module, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
def init(module, opts) do
name = opts[:name] || module

{name, modules} when not is_nil(name) and is_list(modules) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
{modules, jit} =
__MODULE__.LockedCache.run(name, fn ->
modules = collect_modules(module)
{:ok, jit} = do_init(modules)
{modules, jit}
end)

# modules will be nil if cache is hit
for m when is_atom(module) <- modules || [],
module != m do
__MODULE__.LockedCache.run(m, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
end

{:ok, jit}
end

@doc """
Expand Down
1 change: 0 additions & 1 deletion test/const_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ defmodule ConstTest do
one = const 1.0 :: unranked_tensor(f64())
end
end
|> Charms.JIT.init()
end
end
end
2 changes: 0 additions & 2 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ defmodule DefmTest do
end

test "quick sort" do
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])
assert ENIFQuickSort.sort(:what, :arg_err) == :arg_err
arr = [5, 4, 3, 2, 1]
assert ENIFQuickSort.sort(arr, :arg_err) == Enum.sort(arr)
Expand Down
1 change: 0 additions & 1 deletion test/vec_add_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ defmodule VecAddTest do
use ExUnit.Case, async: true

test "vec add" do
{:ok, _} = Charms.JIT.init(AddTwoIntVec)
a = 1..8 |> Enum.to_list()
b = List.duplicate(1, 8)
assert AddTwoIntVec.add(a, b, :err) == Enum.to_list(2..9)
Expand Down
Loading