Skip to content

Commit

Permalink
chore: update jaxns dependency (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve authored Dec 8, 2024
1 parent 831f1ae commit cd438e2
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"iminuit",
"jax>=0.4.28",
"jaxlib>=0.4.28",
"jaxns==2.6.3",
"jaxns==2.6.7",
"matplotlib",
"nautilus-sampler",
"numpy",
Expand Down
11 changes: 5 additions & 6 deletions src/elisa/infer/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
if TYPE_CHECKING:
from typing import Any, Callable, Literal

from jaxlib.xla_client import Device
from prettytable import PrettyTable

from elisa.infer.likelihood import Statistic
Expand Down Expand Up @@ -644,7 +645,7 @@ def jaxns(
s: int | None = None,
k: int | None = None,
c: int | None = None,
num_parallel_workers: int = 1,
devices: list[Device] | None = None,
difficult_model: bool = False,
parameter_estimation: bool = False,
verbose: bool = False,
Expand All @@ -671,8 +672,8 @@ def jaxns(
Number of parallel Markov chains. The default is 30 * `D`, where
`D` is the dimension of model parameters. It takes effect only
for num_live_points=None.
num_parallel_workers : int, optional
Parallel workers number. The default is 1.
devices : list, optional
Devices to use. Defaults to all available devices.
difficult_model : bool, optional
If True, uses more robust default settings (`s` = 10 and
`c` = 50 * `D`). It takes effect only for `num_live_points` = None,
Expand All @@ -699,15 +700,13 @@ def jaxns(
.. [1] `Phantom-Powered Nested Sampling <https://arxiv.org/abs/2312.11330>`__
.. [2] `JAXNS API doc <https://jaxns.readthedocs.io/en/latest/api/jaxns/index.html#jaxns.DefaultNestedSampler>`__
"""
num_parallel_workers = int(num_parallel_workers)

constructor_kwargs = {
'max_samples': max_samples,
'num_live_points': num_live_points,
's': s,
'k': k,
'c': c,
'num_parallel_workers': num_parallel_workers,
'devices': devices,
'difficult_model': difficult_model,
'parameter_estimation': parameter_estimation,
'verbose': verbose,
Expand Down
6 changes: 3 additions & 3 deletions src/elisa/infer/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def prior_model():

default_constructor_kwargs = dict(
num_live_points=model.U_ndims * 25,
num_parallel_workers=1,
devices=jax.devices(),
max_samples=1e4,
)
default_termination_kwargs = dict(dlogZ=1e-4)
Expand All @@ -272,8 +272,8 @@ def prior_model():
)

# TODO: check if this is necessary
# jit when num_parallel_workers is 1
if self.constructor_kwargs['num_parallel_workers'] == 1:
# jit when running on single device
if len(default_ns.nested_sampler.devices) == 1:
run_default_ns = jax.jit(default_ns)
else:
run_default_ns = default_ns
Expand Down

0 comments on commit cd438e2

Please sign in to comment.