From 0222e3664aa14cb5274bb6b77170e4965dba2b38 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Mon, 13 Jan 2025 14:28:36 -0800 Subject: [PATCH] Update formatting for latest Ruff version --- benchmarks/text_generation.py | 2 +- keras_hub/src/metrics/bleu.py | 4 ++-- keras_hub/src/models/basnet/basnet_backbone.py | 2 +- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py | 6 +++--- keras_hub/src/models/densenet/densenet_backbone.py | 6 +++--- keras_hub/src/models/flux/flux_text_to_image.py | 2 +- keras_hub/src/models/pali_gemma/pali_gemma_presets.py | 4 ++-- keras_hub/src/models/resnet/resnet_backbone.py | 2 +- keras_hub/src/models/retinanet/feature_pyramid.py | 10 +++++----- .../stable_diffusion_3/stable_diffusion_3_inpaint.py | 3 +-- keras_hub/src/models/vit/vit_layers.py | 2 +- keras_hub/src/tokenizers/byte_tokenizer.py | 3 +-- .../src/tokenizers/unicode_codepoint_tokenizer.py | 3 +-- keras_hub/src/utils/timm/convert_densenet.py | 10 ++++------ keras_hub/src/utils/timm/convert_efficientnet.py | 2 +- keras_hub/src/utils/timm/convert_resnet.py | 2 +- .../convert_albert_checkpoints.py | 2 +- .../checkpoint_conversion/convert_bart_checkpoints.py | 2 +- .../checkpoint_conversion/convert_bloom_checkpoints.py | 8 ++++---- .../checkpoint_conversion/convert_clip_checkpoints.py | 2 +- .../convert_deberta_v3_checkpoints.py | 2 +- .../convert_distilbert_checkpoints.py | 2 +- .../convert_electra_checkpoints.py | 2 +- .../checkpoint_conversion/convert_f_net_checkpoints.py | 2 +- .../convert_falcon_checkpoints.py | 2 +- .../checkpoint_conversion/convert_gemma_checkpoints.py | 8 ++++---- .../checkpoint_conversion/convert_gpt2_checkpoints.py | 6 +++--- .../convert_llama3_checkpoints.py | 2 +- .../checkpoint_conversion/convert_llama_checkpoints.py | 2 +- .../convert_mistral_checkpoints.py | 2 +- tools/checkpoint_conversion/convert_mix_transformer.py | 2 +- tools/checkpoint_conversion/convert_opt_checkpoints.py | 2 +- .../convert_pali_gemma2_checkpoints.py | 2 +- .../checkpoint_conversion/convert_phi3_checkpoints.py | 8 ++++---- .../convert_resnet_vd_checkpoints.py | 10 +++++----- .../convert_roberta_checkpoints.py | 6 +++--- .../convert_segformer_checkpoints.py | 2 +- .../convert_stable_diffusion_3_checkpoints.py | 6 +++--- tools/checkpoint_conversion/convert_t5_checkpoints.py | 2 +- tools/checkpoint_conversion/convert_vit_checkpoints.py | 2 +- .../convert_xlm_roberta_checkpoints.py | 6 +++--- tools/gemma/export_gemma_to_hf.py | 8 ++++---- tools/gemma/export_gemma_to_torch_xla.py | 8 ++++---- tools/gemma/run_gemma_xla.py | 2 +- 44 files changed, 84 insertions(+), 89 deletions(-) diff --git a/benchmarks/text_generation.py b/benchmarks/text_generation.py index cda95499e8..79811af718 100644 --- a/benchmarks/text_generation.py +++ b/benchmarks/text_generation.py @@ -150,7 +150,7 @@ def next(prompt, state, index): ) print("Time taken: ", time_taken) res_handler.write( - f"{sampler},{execution_method}," f"{time_taken}\n" + f"{sampler},{execution_method},{time_taken}\n" ) print() print("*************************************") diff --git a/keras_hub/src/metrics/bleu.py b/keras_hub/src/metrics/bleu.py index ced7a1ffd9..31727ee045 100644 --- a/keras_hub/src/metrics/bleu.py +++ b/keras_hub/src/metrics/bleu.py @@ -329,8 +329,8 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0): return tf.squeeze(inputs, axis=-1) else: raise ValueError( - f"{tensor_name} must be of rank {base_rank}, {base_rank+1} " - f"or {base_rank+2}. Found rank: {inputs.shape.rank}" + f"{tensor_name} must be of rank {base_rank}, {base_rank + 1} " + f"or {base_rank + 2}. Found rank: {inputs.shape.rank}" ) y_true = validate_and_fix_rank(y_true, "y_true", 1) diff --git a/keras_hub/src/models/basnet/basnet_backbone.py b/keras_hub/src/models/basnet/basnet_backbone.py index 3ae15fc9bb..ff2e65fcb1 100644 --- a/keras_hub/src/models/basnet/basnet_backbone.py +++ b/keras_hub/src/models/basnet/basnet_backbone.py @@ -219,7 +219,7 @@ def get_resnet_block(_resnet, block_num): else: x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]] y = _resnet.get_layer( - f"stack{block_num}_block{num_blocks[block_num]-1}_add" + f"stack{block_num}_block{num_blocks[block_num] - 1}_add" ).output return keras.models.Model( inputs=x, diff --git a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py index 837e508d2c..7b55e3808c 100644 --- a/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +++ b/keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py @@ -88,13 +88,13 @@ def build(self, input_shape): dilation_rate=dilation_rate, use_bias=False, data_format=self.data_format, - name=f"aspp_conv_{i+2}", + name=f"aspp_conv_{i + 2}", ), keras.layers.BatchNormalization( - axis=self.channel_axis, name=f"aspp_bn_{i+2}" + axis=self.channel_axis, name=f"aspp_bn_{i + 2}" ), keras.layers.Activation( - self.activation, name=f"aspp_activation_{i+2}" + self.activation, name=f"aspp_activation_{i + 2}" ), ] ) diff --git a/keras_hub/src/models/densenet/densenet_backbone.py b/keras_hub/src/models/densenet/densenet_backbone.py index 6efddf4d20..4663ea0946 100644 --- a/keras_hub/src/models/densenet/densenet_backbone.py +++ b/keras_hub/src/models/densenet/densenet_backbone.py @@ -81,14 +81,14 @@ def __init__( channel_axis, stackwise_num_repeats[stack_index], growth_rate, - name=f"stack{stack_index+1}", + name=f"stack{stack_index + 1}", ) pyramid_outputs[f"P{index}"] = x x = apply_transition_block( x, channel_axis, compression_ratio, - name=f"transition{stack_index+1}", + name=f"transition{stack_index + 1}", ) x = apply_dense_block( @@ -140,7 +140,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None): for i in range(num_repeats): x = apply_conv_block( - x, channel_axis, growth_rate, name=f"{name}_block{i+1}" + x, channel_axis, growth_rate, name=f"{name}_block{i + 1}" ) return x diff --git a/keras_hub/src/models/flux/flux_text_to_image.py b/keras_hub/src/models/flux/flux_text_to_image.py index 29cb302a1c..f1eab83a85 100644 --- a/keras_hub/src/models/flux/flux_text_to_image.py +++ b/keras_hub/src/models/flux/flux_text_to_image.py @@ -81,7 +81,7 @@ def __init__( def fit(self, *args, **kwargs): raise NotImplementedError( - "Currently, `fit` is not supported for " "`FluxTextToImage`." + "Currently, `fit` is not supported for `FluxTextToImage`." ) def generate_step( diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py index 39f28643a6..615cc6c5ec 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_presets.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_presets.py @@ -5,7 +5,7 @@ "pali_gemma_3b_mix_224": { "metadata": { "description": ( - "image size 224, mix fine tuned, text sequence " "length is 256" + "image size 224, mix fine tuned, text sequence length is 256" ), "params": 2923335408, "path": "pali_gemma", @@ -45,7 +45,7 @@ "pali_gemma_3b_896": { "metadata": { "description": ( - "image size 896, pre trained, text sequence length " "is 512" + "image size 896, pre trained, text sequence length is 512" ), "params": 2927759088, "path": "pali_gemma", diff --git a/keras_hub/src/models/resnet/resnet_backbone.py b/keras_hub/src/models/resnet/resnet_backbone.py index 407ce44f5b..86592002e5 100644 --- a/keras_hub/src/models/resnet/resnet_backbone.py +++ b/keras_hub/src/models/resnet/resnet_backbone.py @@ -177,7 +177,7 @@ def __init__( use_bias=False, padding="same", dtype=dtype, - name=f"conv{conv_index+1}_conv", + name=f"conv{conv_index + 1}_conv", )(x) if not use_pre_activation: diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 1d60db238b..3bdc5a17b1 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -209,9 +209,9 @@ def build(self, input_shapes): ) if i == backbone_max_level + 1 and self.use_p5: self.output_conv_layers[level].build( - (None, None, None, input_shapes[f"P{i-1}"][-1]) + (None, None, None, input_shapes[f"P{i - 1}"][-1]) if self.data_format == "channels_last" - else (None, input_shapes[f"P{i-1}"][1], None, None) + else (None, input_shapes[f"P{i - 1}"][1], None, None) ) else: self.output_conv_layers[level].build( @@ -277,7 +277,7 @@ def call(self, inputs): if i < backbone_max_level: # for the top most output, it doesn't need to merge with any # upper stream outputs - upstream_output = self.top_down_op(output_features[f"P{i+1}"]) + upstream_output = self.top_down_op(output_features[f"P{i + 1}"]) output = self.merge_op([output, upstream_output]) output_features[level] = ( self.lateral_batch_norm_layers[level](output) @@ -296,9 +296,9 @@ def call(self, inputs): for i in range(backbone_max_level + 1, self.max_level + 1): level = f"P{i}" feats_in = ( - inputs[f"P{i-1}"] + inputs[f"P{i - 1}"] if i == backbone_max_level + 1 and self.use_p5 - else output_features[f"P{i-1}"] + else output_features[f"P{i - 1}"] ) if i > backbone_max_level + 1: feats_in = self.activation(feats_in) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py index 7c1c0ba7ab..6f2cc85f86 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py @@ -82,8 +82,7 @@ def __init__( def fit(self, *args, **kwargs): raise NotImplementedError( - "Currently, `fit` is not supported for " - "`StableDiffusion3Inpaint`." + "Currently, `fit` is not supported for `StableDiffusion3Inpaint`." ) def generate_step( diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 8cdc52ca71..f3509440d5 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -351,7 +351,7 @@ def build(self, input_shape): attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, - name=f"tranformer_block_{i+1}", + name=f"tranformer_block_{i + 1}", ) encoder_block.build((None, None, self.hidden_dim)) self.encoder_layers.append(encoder_block) diff --git a/keras_hub/src/tokenizers/byte_tokenizer.py b/keras_hub/src/tokenizers/byte_tokenizer.py index d3bd04aa1b..ebb61c5b91 100644 --- a/keras_hub/src/tokenizers/byte_tokenizer.py +++ b/keras_hub/src/tokenizers/byte_tokenizer.py @@ -150,8 +150,7 @@ def __init__( ): if not is_int_dtype(dtype): raise ValueError( - "Output dtype must be an integer type. " - f"Received: dtype={dtype}" + f"Output dtype must be an integer type. Received: dtype={dtype}" ) # Check normalization_form. diff --git a/keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py b/keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py index 35218e7f65..b63b29baa8 100644 --- a/keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py @@ -203,8 +203,7 @@ def __init__( ) -> None: if not is_int_dtype(dtype): raise ValueError( - "Output dtype must be an integer type. " - f"Received: dtype={dtype}" + f"Output dtype must be an integer type. Received: dtype={dtype}" ) # Check normalization_form. diff --git a/keras_hub/src/utils/timm/convert_densenet.py b/keras_hub/src/utils/timm/convert_densenet.py index d7baa1cbe6..a9bca88b4b 100644 --- a/keras_hub/src/utils/timm/convert_densenet.py +++ b/keras_hub/src/utils/timm/convert_densenet.py @@ -59,18 +59,16 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): num_stacks = len(backbone.stackwise_num_repeats) for stack_index in range(num_stacks): for block_idx in range(backbone.stackwise_num_repeats[stack_index]): - keras_name = f"stack{stack_index+1}_block{block_idx+1}" - hf_name = ( - f"features.denseblock{stack_index+1}.denselayer{block_idx+1}" - ) + keras_name = f"stack{stack_index + 1}_block{block_idx + 1}" + hf_name = f"features.denseblock{stack_index + 1}.denselayer{block_idx + 1}" port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1") port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1") port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.norm2") port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2") for stack_index in range(num_stacks - 1): - keras_transition_name = f"transition{stack_index+1}" - hf_transition_name = f"features.transition{stack_index+1}" + keras_transition_name = f"transition{stack_index + 1}" + hf_transition_name = f"features.transition{stack_index + 1}" port_batch_normalization( f"{keras_transition_name}_bn", f"{hf_transition_name}.norm" ) diff --git a/keras_hub/src/utils/timm/convert_efficientnet.py b/keras_hub/src/utils/timm/convert_efficientnet.py index a8ca08c773..140cbee1d0 100644 --- a/keras_hub/src/utils/timm/convert_efficientnet.py +++ b/keras_hub/src/utils/timm/convert_efficientnet.py @@ -268,7 +268,7 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): # 97 is the start of the lowercase alphabet. letter_identifier = chr(block_idx + 97) - keras_block_prefix = f"block{stack_index+1}{letter_identifier}_" + keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_" hf_block_prefix = f"blocks.{stack_index}.{block_idx}." if block_type == "v1": diff --git a/keras_hub/src/utils/timm/convert_resnet.py b/keras_hub/src/utils/timm/convert_resnet.py index 0d2a8cbfd3..779091b062 100644 --- a/keras_hub/src/utils/timm/convert_resnet.py +++ b/keras_hub/src/utils/timm/convert_resnet.py @@ -89,7 +89,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): for block_idx in range(backbone.stackwise_num_blocks[stack_index]): if version == "v1": keras_name = f"stack{stack_index}_block{block_idx}" - hf_name = f"layer{stack_index+1}.{block_idx}" + hf_name = f"layer{stack_index + 1}.{block_idx}" else: keras_name = f"stack{stack_index}_block{block_idx}" hf_name = f"stages.{stack_index}.blocks.{block_idx}" diff --git a/tools/checkpoint_conversion/convert_albert_checkpoints.py b/tools/checkpoint_conversion/convert_albert_checkpoints.py index 5336762bb4..c450d28747 100644 --- a/tools/checkpoint_conversion/convert_albert_checkpoints.py +++ b/tools/checkpoint_conversion/convert_albert_checkpoints.py @@ -20,7 +20,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_bart_checkpoints.py b/tools/checkpoint_conversion/convert_bart_checkpoints.py index 49fa465d8f..9563348a86 100644 --- a/tools/checkpoint_conversion/convert_bart_checkpoints.py +++ b/tools/checkpoint_conversion/convert_bart_checkpoints.py @@ -18,7 +18,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py index 5903030398..7251249ffc 100644 --- a/tools/checkpoint_conversion/convert_bloom_checkpoints.py +++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py @@ -46,7 +46,7 @@ flags.DEFINE_string( - "preset", None, f'Must be one of {", ".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {', '.join(PRESET_MAP.keys())}" ) flags.mark_flag_as_required("preset") flags.DEFINE_boolean( @@ -244,9 +244,9 @@ def preprocessor_call(input_str): def main(_): preset = FLAGS.preset - assert ( - preset in PRESET_MAP.keys() - ), f'Invalid preset {preset}. Must be one of {", ".join(PRESET_MAP.keys())}' + assert preset in PRESET_MAP.keys(), ( + f"Invalid preset {preset}. Must be one of {', '.join(PRESET_MAP.keys())}" + ) validate_only = FLAGS.validate_only diff --git a/tools/checkpoint_conversion/convert_clip_checkpoints.py b/tools/checkpoint_conversion/convert_clip_checkpoints.py index 113c037e25..861992f7e3 100644 --- a/tools/checkpoint_conversion/convert_clip_checkpoints.py +++ b/tools/checkpoint_conversion/convert_clip_checkpoints.py @@ -65,7 +65,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}', + f"Must be one of {','.join(PRESET_MAP.keys())}", required=True, ) flags.DEFINE_string( diff --git a/tools/checkpoint_conversion/convert_deberta_v3_checkpoints.py b/tools/checkpoint_conversion/convert_deberta_v3_checkpoints.py index b464747994..c456c755c8 100644 --- a/tools/checkpoint_conversion/convert_deberta_v3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_deberta_v3_checkpoints.py @@ -27,7 +27,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_distilbert_checkpoints.py b/tools/checkpoint_conversion/convert_distilbert_checkpoints.py index b3a67c6c7d..04250386ab 100644 --- a/tools/checkpoint_conversion/convert_distilbert_checkpoints.py +++ b/tools/checkpoint_conversion/convert_distilbert_checkpoints.py @@ -23,7 +23,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_electra_checkpoints.py b/tools/checkpoint_conversion/convert_electra_checkpoints.py index bfcb3e4573..c3ed71a850 100644 --- a/tools/checkpoint_conversion/convert_electra_checkpoints.py +++ b/tools/checkpoint_conversion/convert_electra_checkpoints.py @@ -33,7 +33,7 @@ flags.DEFINE_string( "preset", "electra_base_discriminator_en", - f'Must be one of {",".join(PRESET_MAP)}', + f"Must be one of {','.join(PRESET_MAP)}", ) flags.mark_flag_as_required("preset") diff --git a/tools/checkpoint_conversion/convert_f_net_checkpoints.py b/tools/checkpoint_conversion/convert_f_net_checkpoints.py index b102b1e680..31ec7ba0ad 100644 --- a/tools/checkpoint_conversion/convert_f_net_checkpoints.py +++ b/tools/checkpoint_conversion/convert_f_net_checkpoints.py @@ -17,7 +17,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py index 2ebe2e88fd..68d2ca2578 100644 --- a/tools/checkpoint_conversion/convert_falcon_checkpoints.py +++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py @@ -43,7 +43,7 @@ absl.flags.DEFINE_string( "preset", "falcon_refinedweb_1b_en", - f'Must be one of {",".join(PRESET_MAP.keys())}.', + f"Must be one of {','.join(PRESET_MAP.keys())}.", ) diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py index a41598d70c..e500a8202b 100644 --- a/tools/checkpoint_conversion/convert_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -57,7 +57,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}', + f"Must be one of {','.join(PRESET_MAP.keys())}", required=True, ) @@ -228,9 +228,9 @@ def main(_): flax_dir = FLAGS.flax_dir else: presets = PRESET_MAP.keys() - assert ( - preset in presets - ), f'Invalid preset {preset}. Must be one of {",".join(presets)}' + assert preset in presets, ( + f"Invalid preset {preset}. Must be one of {','.join(presets)}" + ) handle = PRESET_MAP[preset] flax_dir = download_flax_model(handle) diff --git a/tools/checkpoint_conversion/convert_gpt2_checkpoints.py b/tools/checkpoint_conversion/convert_gpt2_checkpoints.py index 882ddc3a64..00bdba477e 100644 --- a/tools/checkpoint_conversion/convert_gpt2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gpt2_checkpoints.py @@ -28,7 +28,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) @@ -236,8 +236,8 @@ def check_output( def main(_): assert FLAGS.preset in PRESET_MAP.keys(), ( - f'Invalid preset {FLAGS.preset}. ' - f'Must be one of {",".join(PRESET_MAP.keys())}' + f"Invalid preset {FLAGS.preset}. " + f"Must be one of {','.join(PRESET_MAP.keys())}" ) num_params = PRESET_MAP[FLAGS.preset][0] hf_model_name = PRESET_MAP[FLAGS.preset][1] diff --git a/tools/checkpoint_conversion/convert_llama3_checkpoints.py b/tools/checkpoint_conversion/convert_llama3_checkpoints.py index d8433705b3..41c6c7ad77 100644 --- a/tools/checkpoint_conversion/convert_llama3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_llama3_checkpoints.py @@ -22,7 +22,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py index 3a5b25bb02..e3af5a600a 100644 --- a/tools/checkpoint_conversion/convert_llama_checkpoints.py +++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py @@ -33,7 +33,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) flags.DEFINE_string( diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py index aada40adc0..5d5ac3f2a1 100644 --- a/tools/checkpoint_conversion/convert_mistral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -25,7 +25,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_mix_transformer.py b/tools/checkpoint_conversion/convert_mix_transformer.py index 7ab51768ac..d20c3cc003 100644 --- a/tools/checkpoint_conversion/convert_mix_transformer.py +++ b/tools/checkpoint_conversion/convert_mix_transformer.py @@ -52,7 +52,7 @@ } flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' + "preset", None, f"Must be one of {','.join(DOWNLOAD_URLS.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_opt_checkpoints.py b/tools/checkpoint_conversion/convert_opt_checkpoints.py index 67bbf10987..835151b4dd 100644 --- a/tools/checkpoint_conversion/convert_opt_checkpoints.py +++ b/tools/checkpoint_conversion/convert_opt_checkpoints.py @@ -19,7 +19,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py index 72f18f6fd3..cf8beb1902 100644 --- a/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py +++ b/tools/checkpoint_conversion/convert_pali_gemma2_checkpoints.py @@ -84,7 +84,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}', + f"Must be one of {','.join(PRESET_MAP.keys())}", required=True, ) flags.DEFINE_string( diff --git a/tools/checkpoint_conversion/convert_phi3_checkpoints.py b/tools/checkpoint_conversion/convert_phi3_checkpoints.py index 9b2a2d7c79..8b8a53137f 100644 --- a/tools/checkpoint_conversion/convert_phi3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_phi3_checkpoints.py @@ -67,9 +67,9 @@ def convert_tokenizer(hf_model_dir): model_proto.pieces.append(new_token) tokenizer = Phi3Tokenizer(model_proto.SerializeToString()) for key, value in added_tokens.items(): - assert key == tokenizer.id_to_token( - value - ), f"{key} token have different id in the tokenizer" + assert key == tokenizer.id_to_token(value), ( + f"{key} token have different id in the tokenizer" + ) return tokenizer @@ -324,7 +324,7 @@ def main(): default="phi3_mini_4k_instruct_en", choices=PRESET_MAP.keys(), required=True, - help=f'Preset must be one of {", ".join(PRESET_MAP.keys())}', + help=f"Preset must be one of {', '.join(PRESET_MAP.keys())}", ) def device_regex(arg_value, pattern=re.compile(r"^cpu$|^cuda:[0-9]+$")): diff --git a/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py index a8a424f494..00f9865049 100644 --- a/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py +++ b/tools/checkpoint_conversion/convert_resnet_vd_checkpoints.py @@ -120,8 +120,8 @@ def map_residual_layer_name(name: str): # in PaddleClas. first try a mapping in the form # stack2_block3_1_conv -> res4b2_branch2a paddle_address = ( - f'{int(match["stack"])+2}b{int(match["block"])}' - f'_branch{branch_mapping[int(match["conv"])]}' + f"{int(match['stack']) + 2}b{int(match['block'])}" + f"_branch{branch_mapping[int(match['conv'])]}" ) if match["type"] == "bn": paddle_name = f"bn{paddle_address}" @@ -133,8 +133,8 @@ def map_residual_layer_name(name: str): # if that was not successful, try a mapping like # stack2_block3_1_conv -> res4c_branch2a paddle_address = ( - f'{int(match["stack"])+2}{"abcdefghijkl"[int(match["block"])]}' - f'_branch{branch_mapping[int(match["conv"])]}' + f"{int(match['stack']) + 2}{'abcdefghijkl'[int(match['block'])]}" + f"_branch{branch_mapping[int(match['conv'])]}" ) if match["type"] == "bn": paddle_name = f"bn{paddle_address}" @@ -278,7 +278,7 @@ def set_dense_layer( "elephant.jpg", ) -print(f'{"Model": <25}Error') +print(f"{'Model': <25}Error") for architecture_name in configurations: # PaddleClas prediction predictor = paddleclas.PaddleClas(model_name=architecture_name).predictor diff --git a/tools/checkpoint_conversion/convert_roberta_checkpoints.py b/tools/checkpoint_conversion/convert_roberta_checkpoints.py index 7dba81abf0..e67b488973 100644 --- a/tools/checkpoint_conversion/convert_roberta_checkpoints.py +++ b/tools/checkpoint_conversion/convert_roberta_checkpoints.py @@ -24,7 +24,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) @@ -310,8 +310,8 @@ def check_output( def main(_): assert FLAGS.preset in PRESET_MAP.keys(), ( - f'Invalid preset {FLAGS.preset}. ' - f'Must be one of {",".join(PRESET_MAP.keys())}' + f"Invalid preset {FLAGS.preset}. " + f"Must be one of {','.join(PRESET_MAP.keys())}" ) size = PRESET_MAP[FLAGS.preset][0] hf_model_name = PRESET_MAP[FLAGS.preset][1] diff --git a/tools/checkpoint_conversion/convert_segformer_checkpoints.py b/tools/checkpoint_conversion/convert_segformer_checkpoints.py index 23b92a6b16..d87d590e71 100644 --- a/tools/checkpoint_conversion/convert_segformer_checkpoints.py +++ b/tools/checkpoint_conversion/convert_segformer_checkpoints.py @@ -46,7 +46,7 @@ } flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}' + "preset", None, f"Must be one of {','.join(DOWNLOAD_URLS.keys())}" ) diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 5ee81bad82..272fce6040 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -98,7 +98,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}', + f"Must be one of {','.join(PRESET_MAP.keys())}", required=True, ) flags.DEFINE_string( @@ -571,7 +571,7 @@ def port_attention(loader, keras_variable, hf_weight_key): for i, _ in enumerate(decoder.stackwise_num_filters): for j in range(decoder.stackwise_num_blocks[i]): n = len(decoder.stackwise_num_blocks) - 1 - prefix = f"decoder.up.{n-i}.block.{j}" + prefix = f"decoder.up.{n - i}.block.{j}" port_resnet_block( loader, decoder.blocks[blocks_idx], prefix ) @@ -580,7 +580,7 @@ def port_attention(loader, keras_variable, hf_weight_key): port_conv2d( loader, decoder.upsamples[upsamples_idx + 1], - f"decoder.up.{n-i}.upsample.conv", + f"decoder.up.{n - i}.upsample.conv", ) upsamples_idx += 2 # Skip `UpSampling2D`. diff --git a/tools/checkpoint_conversion/convert_t5_checkpoints.py b/tools/checkpoint_conversion/convert_t5_checkpoints.py index 9eb609a474..842bda61c0 100644 --- a/tools/checkpoint_conversion/convert_t5_checkpoints.py +++ b/tools/checkpoint_conversion/convert_t5_checkpoints.py @@ -102,7 +102,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", "t5_base_multi", f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", "t5_base_multi", f"Must be one of {','.join(PRESET_MAP.keys())}" ) os.environ["KERAS_BACKEND"] = "torch" diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 0777535229..991127461b 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -46,7 +46,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}', + f"Must be one of {','.join(PRESET_MAP.keys())}", required=True, ) flags.DEFINE_string( diff --git a/tools/checkpoint_conversion/convert_xlm_roberta_checkpoints.py b/tools/checkpoint_conversion/convert_xlm_roberta_checkpoints.py index 88b764b1f3..84e6c5eaf4 100644 --- a/tools/checkpoint_conversion/convert_xlm_roberta_checkpoints.py +++ b/tools/checkpoint_conversion/convert_xlm_roberta_checkpoints.py @@ -23,7 +23,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) @@ -291,8 +291,8 @@ def check_output( def main(_): assert FLAGS.preset in PRESET_MAP.keys(), ( - f'Invalid preset {FLAGS.preset}. ' - f'Must be one of {",".join(PRESET_MAP.keys())}' + f"Invalid preset {FLAGS.preset}. " + f"Must be one of {','.join(PRESET_MAP.keys())}" ) size = PRESET_MAP[FLAGS.preset][0] hf_model_name = PRESET_MAP[FLAGS.preset][1] diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py index d364c659d6..a632fbd498 100644 --- a/tools/gemma/export_gemma_to_hf.py +++ b/tools/gemma/export_gemma_to_hf.py @@ -129,7 +129,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}' + f"Must be one of {','.join(PRESET_MAP.keys())}" " Alternatively, a Keras weights file (`.weights.h5`) can be passed" " to --weights_file flag.", ) @@ -345,9 +345,9 @@ def convert_checkpoints( def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: """Updates the state dict for a weight given a tensor.""" - assert ( - tensor.shape == layer.state_dict()[weight_name].shape - ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + assert tensor.shape == layer.state_dict()[weight_name].shape, ( + f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + ) layer.state_dict()[weight_name].copy_(tensor) diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py index c6e43677b0..d6793e87f0 100644 --- a/tools/gemma/export_gemma_to_torch_xla.py +++ b/tools/gemma/export_gemma_to_torch_xla.py @@ -77,7 +77,7 @@ flags.DEFINE_string( "preset", None, - f'Must be one of {",".join(PRESET_MAP.keys())}' + f"Must be one of {','.join(PRESET_MAP.keys())}" " Alternatively, a Keras weights file (`.weights.h5`) can be passed" " to --weights_file flag.", ) @@ -270,9 +270,9 @@ def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir): def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: """Updates the state dict for a weight given a tensor.""" - assert ( - tensor.shape == layer.state_dict()[weight_name].shape - ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + assert tensor.shape == layer.state_dict()[weight_name].shape, ( + f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + ) layer.state_dict()[weight_name].copy_(tensor) diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py index 1f07912184..1a7501c1e6 100644 --- a/tools/gemma/run_gemma_xla.py +++ b/tools/gemma/run_gemma_xla.py @@ -92,7 +92,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" ) flags.DEFINE_string( "size",