Hi,
First of all, I am not very familiar with JAX. My conda environment was built from the distributed YAML file, so it has JAX 0.3.24, as shown below:
jax 0.3.24 pypi_0 pypi
jaxlib 0.3.24 pypi_0 pypi
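As a quick sanity check that the interpreter actually running the script sees the same versions that conda reports, a small helper like the one below can be run inside the environment. It is only a sketch: the function name `installed_versions` is made up for illustration, and `dm-haiku` is included because Haiku appears in the traceback (that is its PyPI distribution name).

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names=("jax", "jaxlib", "dm-haiku")):
    """Return {distribution name: version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Not installed in the environment this interpreter belongs to.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

If these numbers differ from the conda listing, the script is probably being run with a different interpreter than the one in the rfdesign2 environment.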
However, when I run af2_interface_metrics.py on the silent files, I get the following error. I get the same error with both AF 2.3.1 and AF 2.2.4, while af2_metrics.py runs without any issue. Any thoughts on this?
Traceback (most recent call last):
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 597, in <module>
    predict_structure(tag_buffer, feature_dict_dict, binderlen_dict, initial_guess_dict, sfd_out, scorefilename)
  File "/software/RFDesign2/scripts/af2_interface_metrics.py", line 431, in predict_structure
    prediction_result = jax.vmap(model_runner.apply, in_axes=(None,None,0,0))(model_runner.params,
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/software/RFDesign2/envs/.conda/envs/rfdesign2/lib/python3.9/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
TypeError: _forward_fn() takes 1 positional argument but 2 were given
Appreciate it a lot.
Thanks!
modules.AlphaFold is compiled with JAX's jit (via the Haiku-transformed _forward_fn), and the batch should be a dictionary of inputs to the AlphaFold model.
However, _forward_fn() takes only one positional argument, yet line 431 of af2_interface_metrics.py currently seems to pass two. Is that right?
That said, I am surely the one who is wrong, since the tutorial was run with this code, so I wonder what I am missing here. Again, I am pretty new to JAX and would appreciate an idiot's 101 explanation of this. :)
TIA!
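The argument flow behind this error can be reproduced without JAX. Because in_axes=(None, None, 0, 0) has four entries, `jax.vmap` calls `model_runner.apply` with four positional arguments. The first two are consumed as the parameters and the RNG key, and the remaining two are forwarded to the forward function, so a forward function that accepts only `batch` fails exactly as in the traceback. The sketch below uses hypothetical stand-ins (`make_apply`, `forward_one`, `forward_two`) that mimic only how the two Haiku frames in the traceback forward arguments (`f.apply(params, {}, *args, **kwargs)`, then `f(*args, **kwargs)`); they are not Haiku's actual code. `extra` is a placeholder, since the real fourth argument at line 431 is cut off in the traceback.

```python
# Hypothetical stand-ins that mimic ONLY how Haiku forwards positional
# arguments in the two frames of the traceback; this is not Haiku's real code.


def make_apply(forward_fn):
    """Return an apply(params, rng, *args) that forwards *args to forward_fn."""

    def apply_with_state(params, state, rng, *args, **kwargs):
        # cf. transform.py line 357: params, state and rng are consumed here,
        # and every remaining positional argument goes to the forward function.
        out = forward_fn(*args, **kwargs)
        return out, state

    def apply(params, rng, *args, **kwargs):
        # cf. transform.py line 128: an empty state dict is inserted.
        out, _state = apply_with_state(params, {}, rng, *args, **kwargs)
        return out

    return apply


def forward_one(batch):
    """Accepts a single input, like a forward function that only takes `batch`."""
    return sorted(batch)


def forward_two(batch, extra):
    """Hypothetical forward function that also accepts a second input."""
    return sorted(batch), extra


params, rng = {}, None
batch = {"aatype": [0, 1]}           # placeholder feature dict
extra = "hypothetical second input"  # stands in for the truncated 4th argument

# Four positional arguments, matching the four entries of in_axes=(None, None, 0, 0).
try:
    make_apply(forward_one)(params, rng, batch, extra)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)  # forward_one() takes 1 positional argument but 2 were given

result = make_apply(forward_two)(params, rng, batch, extra)
print(result)  # (['aatype'], 'hypothetical second input')
```

The forward_two case only shows that the same call goes through once the forward function accepts both forwarded arguments; whether the script was written against an AlphaFold build with such a forward function is not something the traceback alone can confirm.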