Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove IREE solver #4585

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
14 changes: 0 additions & 14 deletions docs/source/user_guide/installation/gnu-linux-mac.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,6 @@ Users can install ``jax`` and ``jaxlib`` to use the Jax solver.
The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system.

kratman marked this conversation as resolved.
Show resolved Hide resolved
PyBaMM's full `conda-forge distribution <index.rst#installation>`_ (``pybamm``) includes ``jax`` and ``jaxlib`` by default.

.. _optional-iree-mlir-support:

Optional - IREE / MLIR support
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Users can install ``iree`` (for MLIR just-in-time compilation) to use for main expression evaluation in the IDAKLU solver. Requires ``jax``.

.. code:: bash

pip install "pybamm[iree,jax]"

The ``pip install "pybamm[iree,jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``iree`` onto your system.

Uninstall PyBaMM
----------------

Expand Down
12 changes: 0 additions & 12 deletions docs/source/user_guide/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ Optional solvers
The following solvers are optionally available:

* `jax <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`_ -based solver, see :ref:`optional-jaxsolver` .
* `IREE <https://iree.dev/>`_ (`MLIR <https://mlir.llvm.org/>`_) support, see :ref:`optional-iree-mlir-support`.

Dependencies
------------
Expand Down Expand Up @@ -220,17 +219,6 @@ Dependency Minimu
`jaxlib <https://pypi.org/project/jaxlib/>`__ 0.4.20 jax Support library for JAX
========================================================================= ================== ================== =======================

IREE dependencies
^^^^^^^^^^^^^^^^^

Installable with ``pip install "pybamm[iree]"`` (requires ``jax`` dependencies to be installed).

========================================================================= ================== ================== =======================
Dependency Minimum Version pip extra Notes
========================================================================= ================== ================== =======================
`iree-compiler <https://iree.dev/>`__ 20240507.886 iree IREE compiler
========================================================================= ================== ================== =======================

Full installation guide
-----------------------

Expand Down
5 changes: 1 addition & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def set_dev(session):
session.install("virtualenv", "cmake")
session.run("virtualenv", os.fsdecode(VENV_DIR), silent=True)
python = os.fsdecode(VENV_DIR.joinpath("bin/python"))
components = ["all", "dev", "jax"]
args = []
# Temporary fix for Python 3.12 CI. TODO: remove after
# https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with
# is fixed
Expand All @@ -129,8 +127,7 @@ def set_dev(session):
"pip",
"install",
"-e",
".[{}]".format(",".join(components)),
*args,
".[all,dev,jax]",
external=True,
)

Expand Down
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,12 @@ dev = [
"importlib-metadata; python_version < '3.10'",
]
# For the Jax solver.
# Note: These must be kept in sync with the versions defined in pybamm/util.py, and
# must remain compatible with IREE (see noxfile.py for IREE compatibility).
# Note: These must be kept in sync with the versions defined in pybamm/util.py
jax = [
"jax==0.4.27",
"jaxlib==0.4.27",
]
# Contains all optional dependencies, except for jax, iree, and dev dependencies
# Contains all optional dependencies, except for jax and dev dependencies
all = [
"scikit-fem>=8.1.0",
"pybamm[examples,plot,cite,bpx,tqdm]",
Expand Down
5 changes: 1 addition & 4 deletions src/pybamm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from pybamm.version import __version__

# Demote expressions to 32-bit floats/ints - option used for IDAKLU-MLIR compilation
demote_expressions_to_32bit = False

# Utility classes and methods
from .util import root_dir
from .util import Timer, TimerTime, FuzzyDict
Expand Down Expand Up @@ -175,7 +172,7 @@
from .solvers.jax_bdf_solver import jax_bdf_integrate

from .solvers.idaklu_jax import IDAKLUJax
from .solvers.idaklu_solver import IDAKLUSolver, has_iree
from .solvers.idaklu_solver import IDAKLUSolver

# Experiments
from .experiment.experiment import Experiment
Expand Down
47 changes: 1 addition & 46 deletions src/pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,54 +596,9 @@ def __init__(self, symbol: pybamm.Symbol):
static_argnums=self._static_argnums,
)

def _demote_constants(self):
"""Demote 64-bit constants (f64, i64) to 32-bit (f32, i32)"""
if not pybamm.demote_expressions_to_32bit:
return # pragma: no cover
self._constants = EvaluatorJax._demote_64_to_32(self._constants)

@classmethod
def _demote_64_to_32(cls, c):
"""Demote 64-bit operations (f64, i64) to 32-bit (f32, i32)"""

if not pybamm.demote_expressions_to_32bit:
return c
if isinstance(c, float):
c = jax.numpy.float32(c)
if isinstance(c, int):
c = jax.numpy.int32(c)
if isinstance(c, np.int64):
c = c.astype(jax.numpy.int32)
if isinstance(c, np.ndarray):
if c.dtype == np.float64:
c = c.astype(jax.numpy.float32)
if c.dtype == np.int64:
c = c.astype(jax.numpy.int32)
if isinstance(c, jax.numpy.ndarray):
if c.dtype == jax.numpy.float64:
c = c.astype(jax.numpy.float32)
if c.dtype == jax.numpy.int64:
c = c.astype(jax.numpy.int32)
if isinstance(
c, pybamm.expression_tree.operations.evaluate_python.JaxCooMatrix
):
if c.data.dtype == np.float64:
c.data = c.data.astype(jax.numpy.float32)
if c.row.dtype == np.int64:
c.row = c.row.astype(jax.numpy.int32)
if c.col.dtype == np.int64:
c.col = c.col.astype(jax.numpy.int32)
if isinstance(c, dict):
c = {key: EvaluatorJax._demote_64_to_32(value) for key, value in c.items()}
if isinstance(c, tuple):
c = tuple(EvaluatorJax._demote_64_to_32(value) for value in c)
if isinstance(c, list):
c = [EvaluatorJax._demote_64_to_32(value) for value in c]
return c

@property
def _constants(self):
return tuple(map(EvaluatorJax._demote_64_to_32, self.__constants))
return self.__constants

@_constants.setter
def _constants(self, value):
Expand Down
Loading
Loading