From d6bd728d9ed8deceff6f6c18e77e833a14e09a75 Mon Sep 17 00:00:00 2001 From: hy395 Date: Thu, 26 Dec 2024 14:24:21 -0800 Subject: [PATCH] only allow load trunk --- docs/transfer/transfer.md | 27 +++++++++++++---------- src/baskerville/scripts/hound_transfer.py | 8 +++---- src/baskerville/seqnn.py | 7 +----- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/docs/transfer/transfer.md b/docs/transfer/transfer.md index e1b6641..472b726 100644 --- a/docs/transfer/transfer.md +++ b/docs/transfer/transfer.md @@ -19,15 +19,16 @@ w5_folder=${data_path}/w5 mkdir -p ${data_path} ``` -Download four replicate Borzoi pre-trained models (with identical train, validation and test splits (test = fold3, validation = fold4): +Download four replicate Borzoi pre-trained model trunks: ```bash -mkdir -p ${data_path}/weights -wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f0/model0_best.h5" -O ${data_path}/weights/borzoi_r0.h5 -wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f1/model0_best.h5" -O ${data_path}/weights/borzoi_r1.h5 -wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f2/model0_best.h5" -O ${data_path}/weights/borzoi_r2.h5 -wget --progress=bar:force "https://storage.googleapis.com/seqnn-share/borzoi/f3/model0_best.h5" -O ${data_path}/weights/borzoi_r3.h5 +gsutil cp -r gs://scbasset_tutorial_data/baskerville_transfer/pretrain_trunks/ ${data_path} ``` +Note: +- Four replicate models have identical train, validation and test splits (test on fold3, validation on fold4, trained on rest). More details in the Borzoi manuscript. +- Fold splits can be found in trainsplit/sequences.bed. +- Model trunk refers to the model weights without the final dense layer (head). + Download hg38 reference information, and train-validation-test-split information: @@ -185,17 +186,19 @@ westminster_train_folds.py \ ${data_path}/tfr ``` -Run hound_transfer.py on fold3 data for 4 replicate models: +Run hound_transfer.py on training data in fold3 folder (identical to pre-train split) for four replicate models: ```bash -hound_transfer.py -o train_rep0 --restore ${data_path}/weights/borzoi_r0.h5 params.json train/f3c0/data0 -hound_transfer.py -o train_rep1 --restore ${data_path}/weights/borzoi_r1.h5 params.json train/f3c0/data0 -hound_transfer.py -o train_rep2 --restore ${data_path}/weights/borzoi_r2.h5 params.json train/f3c0/data0 -hound_transfer.py -o train_rep3 --restore ${data_path}/weights/borzoi_r3.h5 params.json train/f3c0/data0 +hound_transfer.py -o train_rep0 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r0.h5 params.json train/f3c0/data0 +hound_transfer.py -o train_rep1 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r1.h5 params.json train/f3c0/data0 +hound_transfer.py -o train_rep2 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r2.h5 params.json train/f3c0/data0 +hound_transfer.py -o train_rep3 --trunk --restore ${data_path}/pretrain_trunks/borzoi_r3.h5 params.json train/f3c0/data0 ``` +Note: we recommend loading the model trunk only. While it is possible to load full Borzoi model and ignore last dense layer by model.load_weights(weight_file, skip_mismatch=True, by_name=True), Tensorflow requires loading layer weight by name in this way. If layer name don't match, weights of the layer will not be loaded and no warning message will be given. + ### Step 7. Load models We apply weight merging for lora, ia3, and locon weights, and so there is no architecture changes once the model is trained. You can use the same params.json file, and load the train_rep0/model_best.mergeW.h5 weight file. -For houlsby and houlsby_se, model architectures change due to the insertion of adapter modules. New architecture json file can be found in train_rep0/params.json. +For houlsby and houlsby_se, model architectures change due to the insertion of adapter modules. New architecture json file is auto-generated in train_rep0/params.json; diff --git a/src/baskerville/scripts/hound_transfer.py b/src/baskerville/scripts/hound_transfer.py index ab45dc4..e43e2f1 100755 --- a/src/baskerville/scripts/hound_transfer.py +++ b/src/baskerville/scripts/hound_transfer.py @@ -70,7 +70,7 @@ def main(): parser.add_argument( "--restore", default=None, - help="pre-trained weights.h5 [Default: %(default)s]", + help="model trunk h5 file [Default: %(default)s]", ) parser.add_argument( "--trunk", @@ -180,10 +180,8 @@ def main(): seqnn_model = seqnn.SeqNN(params_model) # restore - if args.trunk: + if args.restore: seqnn_model.restore(args.restore, trunk=args.trunk) - else: - seqnn_model.restore(args.restore, pretrain=True) # head params print( @@ -364,7 +362,7 @@ def main(): # restore if args.restore: - seqnn_model.restore(args.restore, args.trunk) + seqnn_model.restore(args.restore, trunk=args.trunk) # initialize trainer seqnn_trainer = trainer.Trainer( diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index d19d5b2..8887f50 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -1020,15 +1020,10 @@ def predict_transform( return preds - def restore(self, model_file, head_i=0, trunk=False, pretrain=False): + def restore(self, model_file, head_i=0, trunk=False): """Restore weights from saved model.""" if trunk: self.model_trunk.load_weights(model_file) - elif pretrain: - self.models[head_i].load_weights( - model_file, by_name=True, skip_mismatch=True - ) - self.model = self.models[head_i] else: self.models[head_i].load_weights(model_file) self.model = self.models[head_i]