Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect referenced modules and init them together #34

Merged
merged 6 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ end

- run the benchmarks of sorting algorithms
```sh
mix run bench/sort.exs
mix run bench/sort_benchmark.exs
mix run bench/list_add_benchmark.exs
```
2 changes: 1 addition & 1 deletion bench/enif_quick_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFQuickSort do
@moduledoc false
use Charms, init: false
use Charms
alias Charms.{Pointer, Term, Env}

defm swap(a :: Pointer.t(), b :: Pointer.t()) do
Expand Down
2 changes: 1 addition & 1 deletion bench/enif_tim_sort.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defmodule ENIFTimSort do
@moduledoc false
use Charms, init: false
use Charms
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
alias Charms.{Pointer, Term, Env}

defm insertion_sort(arr :: Pointer.t(), left :: i32(), right :: i32()) do
Expand Down
5 changes: 0 additions & 5 deletions bench/list_add_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
mod = AddTwoIntVec
Charms.JIT.init(mod)

a = b = Enum.to_list(1..10)
AddTwoIntVec.add(a, b, :err_msg)
AddTwoIntVec.dummy_load_no_make(a, b, :err_msg)
Expand All @@ -26,5 +23,3 @@ Benchee.run(
{a, b}
end
)

Charms.JIT.destroy(mod)
3 changes: 0 additions & 3 deletions bench/sort_benchmark.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])

arr = Enum.to_list(1..10000) |> Enum.shuffle()
ENIFQuickSort.sort(arr, :arg_err)
ENIFMergeSort.sort(arr, :arg_err)
Expand Down
2 changes: 1 addition & 1 deletion bench/vec_add_int_list.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defmodule AddTwoIntVec do
use Charms, init: false
use Charms
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
alias Charms.{SIMD, Term, Pointer}

defm load_list(env, l :: Term.t()) :: SIMD.t(i32(), 8) do
Expand Down
9 changes: 8 additions & 1 deletion lib/charms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@ defmodule Charms do

defmacro __before_compile__(_env) do
quote do
@ir @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
{ir, referenced_modules} = @defm |> Enum.reverse() |> Charms.Defm.compile_definitions()
@ir ir
@referenced_modules referenced_modules

def __ir__ do
@ir
end

def __referenced_modules__ do
@referenced_modules
end
end
end

Expand Down
29 changes: 28 additions & 1 deletion lib/charms/defm.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ defmodule Charms.Defm do
:ok
end

defp referenced_modules(module) do
Beaver.Walker.postwalk(module, MapSet.new(), fn
%MLIR.Operation{} = op, acc ->
if MLIR.Operation.name(op) == "func.call" do
callee = Beaver.Walker.attributes(op)["callee"]
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

acc =
case callee |> to_string do
"@Elixir." <> name ->
[m, _f] = name |> String.split(".")
acc |> MapSet.put(String.to_atom("Elixir.#{m}"))

_ ->
acc
end

{op, acc}
else
{op, acc}
end

ir, acc ->
{ir, acc}
end)
|> then(fn {_, acc} -> MapSet.to_list(acc) end)
end
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

@doc false
def compile_definitions(definitions) do
import MLIR.Transforms
Expand Down Expand Up @@ -186,7 +213,7 @@ defmodule Charms.Defm do
|> MLIR.Pass.Composer.append({"check-poison", "builtin.module", &check_poison!/1})
|> canonicalize
|> MLIR.Pass.Composer.run!(print: Charms.Debug.step_print?())
|> MLIR.to_string(bytecode: true)
|> then(&{MLIR.to_string(&1, bytecode: true), referenced_modules(&1)})
|> tap(fn _ -> MLIR.Context.destroy(ctx) end)
end

Expand Down
40 changes: 35 additions & 5 deletions lib/charms/jit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ defmodule Charms.JIT do
|> then(&{:ok, &1})
end

defp collect_modules(module, acc \\ [])

defp collect_modules(module, acc) when is_atom(module) do
acc = [module | acc]

for m <- module.__referenced_modules__(), m not in acc do
collect_modules(m, acc)
end
|> List.flatten()
|> Enum.concat(acc)
|> Enum.uniq()
end

defp collect_modules(module, _acc), do: [module]

jackalcooper marked this conversation as resolved.
Show resolved Hide resolved
def init(module, opts \\ [])

def init({:module, module, binary, _}, opts) when is_atom(module) and is_binary(binary) do
Expand All @@ -106,16 +121,31 @@ defmodule Charms.JIT do
modules = modules |> List.wrap()

case {opts[:name], modules} do
{name, [_]} when not is_nil(name) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
{name, [m]} when not is_nil(name) ->
{modules, jit} =
__MODULE__.LockedCache.run(name, fn ->
modules = collect_modules(m)
{:ok, jit} = do_init(modules)
{modules, jit}
end)

# modules will be nil if cache is hit
modules = modules || []

for module when is_atom(module) <- modules,
module != m do
__MODULE__.LockedCache.run(module, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
end

{:ok, jit}
jackalcooper marked this conversation as resolved.
Show resolved Hide resolved

{nil, modules} when modules != [] ->
[key | tail] = modules
{:ok, jit} = __MODULE__.LockedCache.run(key, fn -> do_init(modules) end)

for module <- tail,
do:
__MODULE__.LockedCache.run(module, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
for module <- tail do
__MODULE__.LockedCache.run(module, fn -> {:ok, %__MODULE__{jit | owner: false}} end)
end

{name, modules} when not is_nil(name) and is_list(modules) ->
__MODULE__.LockedCache.run(name, fn -> do_init(modules) end)
Expand Down
1 change: 0 additions & 1 deletion test/const_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ defmodule ConstTest do
one = const 1.0 :: unranked_tensor(f64())
end
end
|> Charms.JIT.init()
end
end
end
2 changes: 0 additions & 2 deletions test/defm_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ defmodule DefmTest do
end

test "quick sort" do
Charms.JIT.init(ENIFQuickSort)
Charms.JIT.init([ENIFTimSort, ENIFMergeSort])
assert ENIFQuickSort.sort(:what, :arg_err) == :arg_err
arr = [5, 4, 3, 2, 1]
assert ENIFQuickSort.sort(arr, :arg_err) == Enum.sort(arr)
Expand Down
1 change: 0 additions & 1 deletion test/vec_add_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ defmodule VecAddTest do
use ExUnit.Case, async: true

test "vec add" do
{:ok, _} = Charms.JIT.init(AddTwoIntVec)
a = 1..8 |> Enum.to_list()
b = List.duplicate(1, 8)
assert AddTwoIntVec.add(a, b, :err) == Enum.to_list(2..9)
Expand Down
Loading