From ba65d74b74298592d96a0454e7330461ec955334 Mon Sep 17 00:00:00 2001 From: hy395 Date: Tue, 24 Dec 2024 15:27:42 -0800 Subject: [PATCH] black --- src/baskerville/scripts/hound_transfer.py | 4 ++-- src/baskerville/seqnn.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/baskerville/scripts/hound_transfer.py b/src/baskerville/scripts/hound_transfer.py index 7e8527d..ab45dc4 100755 --- a/src/baskerville/scripts/hound_transfer.py +++ b/src/baskerville/scripts/hound_transfer.py @@ -184,7 +184,7 @@ def main(): seqnn_model.restore(args.restore, trunk=args.trunk) else: seqnn_model.restore(args.restore, pretrain=True) - + # head params print( "params in new head: %d" @@ -281,7 +281,7 @@ def main(): if args.skip_train: exit(0) - + # train model if args.keras_fit: seqnn_trainer.fit_keras(seqnn_model) diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 610a176..d19d5b2 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -1025,7 +1025,9 @@ def restore(self, model_file, head_i=0, trunk=False, pretrain=False): if trunk: self.model_trunk.load_weights(model_file) elif pretrain: - self.models[head_i].load_weights(model_file, by_name=True, skip_mismatch=True) + self.models[head_i].load_weights( + model_file, by_name=True, skip_mismatch=True + ) self.model = self.models[head_i] else: self.models[head_i].load_weights(model_file)