From 537b0ea11206bb50c8b53299f8e1176be9d90c16 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Tue, 16 Jul 2024 13:05:20 +0200 Subject: [PATCH 01/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index da1cb08c..54f86d52 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -193,9 +193,14 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, _np, _softmax = jnp, jax.nn.softmax else: _np, _softmax = np, scipy.special.softmax + jax.debug.print('chain_num={chain_num}',chain_num=chain_num) if chain_num is None: chain_num = 1 + jax.debug.print('chain_num={chain_num}',chain_num=chain_num) + + jax.debug.print('asym_id={asym_id}',asym_id=asym_id) + jax.debug.print('asym_id_shape={asym_id_shape}',asym_id_shape=asym_id.shape) # residue_weights has to be in [0, 1], but can be floating-point, i.e. the # exp. resolved head's probability. @@ -225,6 +230,9 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, def get_cross_iptm(i, j): pair_mask = jnp.logical_and(i * jnp.ones((num_res))[:, None] == asym_id[None, :] , j*jnp.ones((num_res))[None, :] == asym_id[:, None]) + jax.debug.print('pair_mask={pair_mask}',pair_mask=pair_mask) + jax.debug.print('pair_mask_shape={pair_mask_shape}',pair_mask_shape=pair_mask.shape) + jax.debug.print('pair_mask_sum={pair_mask_sum}',pair_mask_sum=pair_mask_sum.sum()) chain_chain_predicted_tm_term = predicted_tm_term * pair_mask pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None]) normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True)) @@ -286,4 +294,5 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F else: mean_score = confidence_metrics["mean_plddt"] confidence_metrics["ranking_confidence"] = mean_score - return confidence_metrics \ No newline at end of file + return confidence_metrics jax.debug.print('chain_num={chain_num}',chain_num=chain_num) + # jax.debug.print('residue weights = {x}',x=residue_weights) From 061c2656a28f25f5815f11082fc6a37ba2cbc871 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Tue, 16 Jul 2024 13:25:24 +0200 Subject: [PATCH 02/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 54f86d52..18071ddd 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -295,4 +295,3 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F mean_score = confidence_metrics["mean_plddt"] confidence_metrics["ranking_confidence"] = mean_score return confidence_metrics jax.debug.print('chain_num={chain_num}',chain_num=chain_num) - # jax.debug.print('residue weights = {x}',x=residue_weights) From 0dc066117ef79b54876ea46587a825edede30b8b Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Tue, 16 Jul 2024 13:32:10 +0200 Subject: [PATCH 03/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 18071ddd..c60cc25e 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -294,4 +294,4 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F else: mean_score = confidence_metrics["mean_plddt"] confidence_metrics["ranking_confidence"] = mean_score - return confidence_metrics jax.debug.print('chain_num={chain_num}',chain_num=chain_num) + return confidence_metrics From 257886b7e0a012eed2b13ae30b27930f224314d7 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Tue, 16 Jul 2024 16:00:43 +0200 Subject: [PATCH 04/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index c60cc25e..04abb949 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -232,7 +232,7 @@ def get_cross_iptm(i, j): pair_mask = jnp.logical_and(i * jnp.ones((num_res))[:, None] == asym_id[None, :] , j*jnp.ones((num_res))[None, :] == asym_id[:, None]) jax.debug.print('pair_mask={pair_mask}',pair_mask=pair_mask) jax.debug.print('pair_mask_shape={pair_mask_shape}',pair_mask_shape=pair_mask.shape) - jax.debug.print('pair_mask_sum={pair_mask_sum}',pair_mask_sum=pair_mask_sum.sum()) + jax.debug.print('pair_mask_sum={pair_mask_sum}',pair_mask_sum=pair_mask.sum()) chain_chain_predicted_tm_term = predicted_tm_term * pair_mask pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None]) normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True)) From ab3109a0f2c67663c15d8dfd35a5bc8266505e7e Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 01:05:57 +0200 Subject: [PATCH 05/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 04abb949..c4a6ed99 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -200,7 +200,11 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - jax.debug.print('asym_id_shape={asym_id_shape}',asym_id_shape=asym_id.shape) + jax.debug.print('asym_id_shape={asym_id_shape}',asym_id_max=jnp.extract(asym_id.shape) + max_asym_id = asym_id.max() + jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) + max_asym_id_as_int = int(max_asym_id[0]) + jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) # residue_weights has to be in [0, 1], but can be floating-point, i.e. the # exp. resolved head's probability. @@ -241,7 +245,8 @@ def get_cross_iptm(i, j): iptm_matrix_list = [] - for i in jnp.arange(chain_num): +# for i in jnp.arange(chain_num): + for i in jnp.arange(max_asym_id_as_int): local_list = [] for j in jnp.arange(chain_num): local_list.append(get_cross_iptm(i, j)) From 474b1db3d33ed9aeafef8b4636db58f60948ddfb Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 01:36:53 +0200 Subject: [PATCH 06/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index c4a6ed99..b49d58c6 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -200,7 +200,6 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - jax.debug.print('asym_id_shape={asym_id_shape}',asym_id_max=jnp.extract(asym_id.shape) max_asym_id = asym_id.max() jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) max_asym_id_as_int = int(max_asym_id[0]) From 0c4cd15ebda94be8b6e2781fcaac3819b8de9b16 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 11:18:05 +0200 Subject: [PATCH 07/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index b49d58c6..bfcfb52a 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -202,7 +202,7 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('asym_id={asym_id}',asym_id=asym_id) max_asym_id = asym_id.max() jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) - max_asym_id_as_int = int(max_asym_id[0]) + max_asym_id_as_int = int(max_asym_id) jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) # residue_weights has to be in [0, 1], but can be floating-point, i.e. the From dd3809e393257407598d7dabb49164bf6b7c2dc7 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 13:40:42 +0200 Subject: [PATCH 08/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index bfcfb52a..96283044 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -20,6 +20,9 @@ from alphafold.common import residue_constants import scipy.special +def apply_fn(batch): + max_asym_id_as_int = max_asym_id.astype(int) + def compute_tol(prev_pos, current_pos, mask, use_jnp=False): # Early stopping criteria based on criteria used in # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2 @@ -202,7 +205,10 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('asym_id={asym_id}',asym_id=asym_id) max_asym_id = asym_id.max() jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) - max_asym_id_as_int = int(max_asym_id) + batch = {'asym_id': max_asym_id)} + apply_fn_jit = jax.jit(apply_fn) +# max_asym_id_as_int = int(max_asym_id) + max_asym_id_as_int = apply_fn_jit(batch) jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) # residue_weights has to be in [0, 1], but can be floating-point, i.e. the From 02e135f7f907fb38ae1132d7873be11f13860a0f Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 13:53:24 +0200 Subject: [PATCH 09/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 96283044..db7fe42f 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -205,7 +205,7 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('asym_id={asym_id}',asym_id=asym_id) max_asym_id = asym_id.max() jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) - batch = {'asym_id': max_asym_id)} + batch = {'asym_id': max_asym_id} apply_fn_jit = jax.jit(apply_fn) # max_asym_id_as_int = int(max_asym_id) max_asym_id_as_int = apply_fn_jit(batch) From fe6542d5ad024d2f5329d34b373817de59ff397c Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 18:36:34 +0200 Subject: [PATCH 10/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index db7fe42f..38e409de 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -21,7 +21,12 @@ import scipy.special def apply_fn(batch): - max_asym_id_as_int = max_asym_id.astype(int) + asym_id = batch['asym_id'] + jax.debug.print('asym_id={asym_id}', asym_id=asym_id) + max_asym_id = asym_id.max() + jax.debug.print('max_asym_id={max_asym_id}', max_asym_id=max_asym_id) + # Return the max_asym_id to be converted outside the JAX-traced function + return max_asym_id def compute_tol(prev_pos, current_pos, mask, use_jnp=False): # Early stopping criteria based on criteria used in @@ -203,13 +208,13 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - max_asym_id = asym_id.max() - jax.debug.print('max_asym_id={max_asym_id}',max_asym_id=max_asym_id) - batch = {'asym_id': max_asym_id} + + batch = {'asym_id': asym_id} apply_fn_jit = jax.jit(apply_fn) -# max_asym_id_as_int = int(max_asym_id) - max_asym_id_as_int = apply_fn_jit(batch) - jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) + max_asym_id_traced = apply_fn_jit(batch) + max_asym_id_as_int = int(max_asym_id_traced) + +# jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) # residue_weights has to be in [0, 1], but can be floating-point, i.e. the # exp. resolved head's probability. From 4c5b1a1bd8fccbc239dcebee46fe9fb46c0c772b Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 19:06:15 +0200 Subject: [PATCH 11/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 38e409de..0a648407 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -20,13 +20,13 @@ from alphafold.common import residue_constants import scipy.special -def apply_fn(batch): - asym_id = batch['asym_id'] - jax.debug.print('asym_id={asym_id}', asym_id=asym_id) - max_asym_id = asym_id.max() - jax.debug.print('max_asym_id={max_asym_id}', max_asym_id=max_asym_id) - # Return the max_asym_id to be converted outside the JAX-traced function - return max_asym_id +@jax.jit +def numpy_callback(x): + # Need to forward-declare the shape & dtype of the expected output. + result_shape = jax.core.ShapedArray(x.shape, x.dtype) + return jax.pure_callback(np.positive, result_shape, x) + + def compute_tol(prev_pos, current_pos, mask, use_jnp=False): # Early stopping criteria based on criteria used in @@ -208,11 +208,14 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - - batch = {'asym_id': asym_id} - apply_fn_jit = jax.jit(apply_fn) - max_asym_id_traced = apply_fn_jit(batch) - max_asym_id_as_int = int(max_asym_id_traced) + chain_nums = numpy_callback(asym_id) # Return the max_asym_id to be converted outside the JAX-traced function + print('chain_nums=', chain_nums) + chain_num = max(chain_nums) + print('chain_num=', chain_num) +# batch = {'asym_id': asym_id} +# apply_fn_jit = jax.jit(apply_fn) +# max_asym_id_traced = apply_fn_jit(batch) +# max_asym_id_as_int = int(max_asym_id_traced) # jax.debug.print('max_asym_id_as_int={max_asym_id_as_int}',max_asym_id_as_int=max_asym_id_as_int) @@ -255,8 +258,7 @@ def get_cross_iptm(i, j): iptm_matrix_list = [] -# for i in jnp.arange(chain_num): - for i in jnp.arange(max_asym_id_as_int): + for i in jnp.arange(chain_num): local_list = [] for j in jnp.arange(chain_num): local_list.append(get_cross_iptm(i, j)) From 833376a0112bc9194fd23a3236ca9997c43590f0 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 19:54:16 +0200 Subject: [PATCH 12/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 0a648407..11445b0a 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -20,12 +20,9 @@ from alphafold.common import residue_constants import scipy.special -@jax.jit -def numpy_callback(x): - # Need to forward-declare the shape & dtype of the expected output. - result_shape = jax.core.ShapedArray(x.shape, x.dtype) - return jax.pure_callback(np.positive, result_shape, x) - +@partial(jit, static_argnums=1) +def func(x, axis): + return x.max(axis) def compute_tol(prev_pos, current_pos, mask, use_jnp=False): @@ -208,9 +205,7 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - chain_nums = numpy_callback(asym_id) # Return the max_asym_id to be converted outside the JAX-traced function - print('chain_nums=', chain_nums) - chain_num = max(chain_nums) + chain_num = func(asym_id, -1) # Return the max_asym_id to be converted outside the JAX-traced function print('chain_num=', chain_num) # batch = {'asym_id': asym_id} # apply_fn_jit = jax.jit(apply_fn) From b771e05eb19fad894174c57c65d72db57f4b876b Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 20:12:42 +0200 Subject: [PATCH 13/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 11445b0a..49746a18 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -13,7 +13,7 @@ # limitations under the License. """Functions for processing confidence metrics.""" - +from functools import partial import jax.numpy as jnp import jax import numpy as np From 8883211a1e9927417600fa9051210d476ea883db Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Wed, 17 Jul 2024 20:22:52 +0200 Subject: [PATCH 14/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 49746a18..235c8ceb 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -14,6 +14,7 @@ """Functions for processing confidence metrics.""" from functools import partial +from jax import jit import jax.numpy as jnp import jax import numpy as np From 261350292d8ed40ebb26d917b508d441013159db Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Sun, 21 Jul 2024 15:49:38 +0200 Subject: [PATCH 15/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 235c8ceb..37d8c397 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -20,10 +20,19 @@ import numpy as np from alphafold.common import residue_constants import scipy.special +### +from jax.experimental import io_callback +from functools import partial + + +### +def func(x): + return x.max(-1).astype(x.dtype) + +@jax.jit +def numpy_random_like(x): + return io_callback(func, x, x) -@partial(jit, static_argnums=1) -def func(x, axis): - return x.max(axis) def compute_tol(prev_pos, current_pos, mask, use_jnp=False): @@ -206,7 +215,8 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, jax.debug.print('chain_num={chain_num}',chain_num=chain_num) jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - chain_num = func(asym_id, -1) # Return the max_asym_id to be converted outside the JAX-traced function + chain_num= numpy_random_like(asym_id) + print('chain_num=', chain_num) # batch = {'asym_id': asym_id} # apply_fn_jit = jax.jit(apply_fn) From 7a0c6b0776d3608f30e3ddc108bcbd9a4a1c3b88 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Sun, 21 Jul 2024 17:07:51 +0200 Subject: [PATCH 16/20] debugging output shape of iptm_chain --- alphafold/common/confidence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index 37d8c397..f27caf09 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -27,7 +27,7 @@ ### def func(x): - return x.max(-1).astype(x.dtype) + return x.max().astype(x.dtype) @jax.jit def numpy_random_like(x): From d9aa49638e76430eefadac0325b2c595f05086a9 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Thu, 25 Jul 2024 17:07:47 +0200 Subject: [PATCH 17/20] fixed PR #9 --- alphafold/common/confidence.py | 20 +------------------- alphafold/model/model.py | 12 ++++++++---- alphafold/model/modules_multimer.py | 8 +++----- 3 files changed, 12 insertions(+), 28 deletions(-) diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py index f27caf09..1d2c80cc 100644 --- a/alphafold/common/confidence.py +++ b/alphafold/common/confidence.py @@ -13,6 +13,7 @@ # limitations under the License. """Functions for processing confidence metrics.""" +from jax import debug from functools import partial from jax import jit import jax.numpy as jnp @@ -21,17 +22,7 @@ from alphafold.common import residue_constants import scipy.special ### -from jax.experimental import io_callback -from functools import partial - -### -def func(x): - return x.max().astype(x.dtype) - -@jax.jit -def numpy_random_like(x): - return io_callback(func, x, x) @@ -208,16 +199,10 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, _np, _softmax = jnp, jax.nn.softmax else: _np, _softmax = np, scipy.special.softmax - jax.debug.print('chain_num={chain_num}',chain_num=chain_num) if chain_num is None: chain_num = 1 - jax.debug.print('chain_num={chain_num}',chain_num=chain_num) - - jax.debug.print('asym_id={asym_id}',asym_id=asym_id) - chain_num= numpy_random_like(asym_id) - print('chain_num=', chain_num) # batch = {'asym_id': asym_id} # apply_fn_jit = jax.jit(apply_fn) # max_asym_id_traced = apply_fn_jit(batch) @@ -253,9 +238,6 @@ def predicted_tm_score_chain(logits, breaks, residue_weights = None, def get_cross_iptm(i, j): pair_mask = jnp.logical_and(i * jnp.ones((num_res))[:, None] == asym_id[None, :] , j*jnp.ones((num_res))[None, :] == asym_id[:, None]) - jax.debug.print('pair_mask={pair_mask}',pair_mask=pair_mask) - jax.debug.print('pair_mask_shape={pair_mask_shape}',pair_mask_shape=pair_mask.shape) - jax.debug.print('pair_mask_sum={pair_mask_sum}',pair_mask_sum=pair_mask.sum()) chain_chain_predicted_tm_term = predicted_tm_term * pair_mask pair_residue_weights = pair_mask * (residue_weights[None, :] * residue_weights[:, None]) normed_residue_mask = pair_residue_weights / (1e-8 + pair_residue_weights.sum(-1, keepdims=True)) diff --git a/alphafold/model/model.py b/alphafold/model/model.py index 88e90f1f..9327dd5f 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -35,15 +35,17 @@ class RunModel: def __init__(self, config: ml_collections.ConfigDict, params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None, - is_training = False): + is_training = False, + chain_num=1): self.config = config self.params = params self.multimer_mode = config.model.global_config.multimer_mode + self.chain_num = chain_num if self.multimer_mode: def _forward_fn(batch): - model = modules_multimer.AlphaFold(self.config.model) + model = modules_multimer.AlphaFold(self.config.model,self.chain_num) return model(batch, is_training=is_training) else: def _forward_fn(batch): @@ -134,6 +136,7 @@ def predict(self, Returns: A dictionary of model outputs. """ + self.init_params(feat) logging.info('Running predict with shape(feat) = %s', tree.map_structure(lambda x: x.shape, feat)) @@ -146,6 +149,7 @@ def predict(self, else: num_ensemble = self.config.data.eval.num_ensemble L = aatype.shape[1] + # initialize @@ -169,6 +173,7 @@ def _jnp_to_np(x): # initialize random key key = jax.random.PRNGKey(random_seed) + # iterate through recyckes for r in range(num_iters): @@ -197,6 +202,5 @@ def _jnp_to_np(x): break if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance: break - logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result)) - return result, r \ No newline at end of file + return result, r diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py index 1e909c94..aa43adbf 100644 --- a/alphafold/model/modules_multimer.py +++ b/alphafold/model/modules_multimer.py @@ -405,11 +405,11 @@ class AlphaFold(hk.Module): """AlphaFold-Multimer model with recycling. """ - def __init__(self, config, name='alphafold'): + def __init__(self, config,chain_num, name='alphafold'): super().__init__(name=name) self.config = config self.global_config = config.global_config - + self.chain_num = chain_num def __call__( self, batch, @@ -418,6 +418,7 @@ def __call__( safe_key=None): c = self.config + chain_num = self.chain_num impl = AlphaFoldIteration(c, self.global_config) if safe_key is None: @@ -461,9 +462,6 @@ def apply_network(prev, safe_key): if not return_representations: del ret['representations'] - # Extract chain NUM - chain_num = c.embeddings_and_evoformer.max_relative_chain + 1 - # add confidence metrics ret.update(confidence.get_confidence_metrics( prediction_result=ret, From 9ef11e58d33a7c1dca9d62ca4b4585fb1e49cc0c Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Fri, 25 Oct 2024 15:52:08 +0200 Subject: [PATCH 18/20] updated version in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 244b7503..8012afee 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name='alphafold-colabfold', - version='2.3.6', + version='2.3.8', long_description_content_type='text/markdown', description='An implementation of the inference pipeline of AlphaFold v2.3.1. ' 'This is a completely new model that was entered as AlphaFold2 in CASP14 ' From f9e40e93993cb6eae393492756589d4f7a78de3f Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Fri, 25 Oct 2024 22:22:30 +0200 Subject: [PATCH 19/20] updated version in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8012afee..b076ac0f 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ author='DeepMind', author_email='alphafold@deepmind.com', license='Apache License, Version 2.0', - url='https://github.com/sokrypton/alphafold', + url='https://github.com/ntnn19/alphafold', packages=find_packages(), install_requires=[ 'absl-py', From 0f5ebef51b1534c3bf4561053b6ebbd0edff3578 Mon Sep 17 00:00:00 2001 From: ntnn19 Date: Fri, 25 Oct 2024 23:07:35 +0200 Subject: [PATCH 20/20] updated version in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b076ac0f..b08a98bd 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ author='DeepMind', author_email='alphafold@deepmind.com', license='Apache License, Version 2.0', - url='https://github.com/ntnn19/alphafold', + url='https://github.com/ntnn19/alphafold/tree/chain_iptm', packages=find_packages(), install_requires=[ 'absl-py',