Update formatting for latest Ruff version
mattdangerw committed Jan 13, 2025
1 parent 6afe94f commit 0222e36
Showing 44 changed files with 84 additions and 89 deletions.
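
The hunks below are mechanical: the Ruff formatter was re-run after its f-string formatting landed (presumably Ruff 0.9's 2025 style guide, though the commit doesn't pin a version). Three patterns account for nearly every change; here is a minimal runnable sketch, with stand-in values borrowed from the diff:

# Stand-ins so the sketch runs; the real code uses KerasHub layer indices
# and checkpoint preset maps.
stack_index = 0
PRESET_MAP = {"pali_gemma_3b_mix_224": None, "pali_gemma_3b_896": None}

# 1. Spaces around binary operators inside f-string replacement fields:
#    f"stack{stack_index+1}" is reformatted to f"stack{stack_index + 1}".
name = f"stack{stack_index + 1}"

# 2. Implicitly concatenated string literals that fit on one line are merged:
#    "Currently, `fit` is not supported for " "`FluxTextToImage`."
#    becomes a single literal.
msg = "Currently, `fit` is not supported for `FluxTextToImage`."

# 3. Quotes are normalized so the outer f-string uses double quotes,
#    flipping the inner literal to single quotes:
#    f'Must be one of {",".join(PRESET_MAP.keys())}' becomes the form below.
flag_help = f"Must be one of {','.join(PRESET_MAP.keys())}"

print(name, msg, flag_help, sep="\n")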
2 changes: 1 addition & 1 deletion benchmarks/text_generation.py
@@ -150,7 +150,7 @@ def next(prompt, state, index):
)
print("Time taken: ", time_taken)
res_handler.write(
- f"{sampler},{execution_method}," f"{time_taken}\n"
+ f"{sampler},{execution_method},{time_taken}\n"
)
print()
print("*************************************")
4 changes: 2 additions & 2 deletions keras_hub/src/metrics/bleu.py
@@ -329,8 +329,8 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
return tf.squeeze(inputs, axis=-1)
else:
raise ValueError(
- f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
- f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
+ f"{tensor_name} must be of rank {base_rank}, {base_rank + 1} "
+ f"or {base_rank + 2}. Found rank: {inputs.shape.rank}"
)

y_true = validate_and_fix_rank(y_true, "y_true", 1)
2 changes: 1 addition & 1 deletion keras_hub/src/models/basnet/basnet_backbone.py
@@ -219,7 +219,7 @@ def get_resnet_block(_resnet, block_num):
else:
x = _resnet.pyramid_outputs[extractor_levels[block_num - 1]]
y = _resnet.get_layer(
- f"stack{block_num}_block{num_blocks[block_num]-1}_add"
+ f"stack{block_num}_block{num_blocks[block_num] - 1}_add"
).output
return keras.models.Model(
inputs=x,
6 changes: 3 additions & 3 deletions keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py
@@ -88,13 +88,13 @@ def build(self, input_shape):
dilation_rate=dilation_rate,
use_bias=False,
data_format=self.data_format,
- name=f"aspp_conv_{i+2}",
+ name=f"aspp_conv_{i + 2}",
),
keras.layers.BatchNormalization(
- axis=self.channel_axis, name=f"aspp_bn_{i+2}"
+ axis=self.channel_axis, name=f"aspp_bn_{i + 2}"
),
keras.layers.Activation(
- self.activation, name=f"aspp_activation_{i+2}"
+ self.activation, name=f"aspp_activation_{i + 2}"
),
]
)
6 changes: 3 additions & 3 deletions keras_hub/src/models/densenet/densenet_backbone.py
@@ -81,14 +81,14 @@ def __init__(
channel_axis,
stackwise_num_repeats[stack_index],
growth_rate,
- name=f"stack{stack_index+1}",
+ name=f"stack{stack_index + 1}",
)
pyramid_outputs[f"P{index}"] = x
x = apply_transition_block(
x,
channel_axis,
compression_ratio,
- name=f"transition{stack_index+1}",
+ name=f"transition{stack_index + 1}",
)

x = apply_dense_block(
@@ -140,7 +140,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):

for i in range(num_repeats):
x = apply_conv_block(
- x, channel_axis, growth_rate, name=f"{name}_block{i+1}"
+ x, channel_axis, growth_rate, name=f"{name}_block{i + 1}"
)
return x
2 changes: 1 addition & 1 deletion keras_hub/src/models/flux/flux_text_to_image.py
@@ -81,7 +81,7 @@ def __init__(

def fit(self, *args, **kwargs):
raise NotImplementedError(
- "Currently, `fit` is not supported for " "`FluxTextToImage`."
+ "Currently, `fit` is not supported for `FluxTextToImage`."
)

def generate_step(
4 changes: 2 additions & 2 deletions keras_hub/src/models/pali_gemma/pali_gemma_presets.py
@@ -5,7 +5,7 @@
"pali_gemma_3b_mix_224": {
"metadata": {
"description": (
- "image size 224, mix fine tuned, text sequence " "length is 256"
+ "image size 224, mix fine tuned, text sequence length is 256"
),
"params": 2923335408,
"path": "pali_gemma",
@@ -45,7 +45,7 @@
"pali_gemma_3b_896": {
"metadata": {
"description": (
- "image size 896, pre trained, text sequence length " "is 512"
+ "image size 896, pre trained, text sequence length is 512"
),
"params": 2927759088,
"path": "pali_gemma",
2 changes: 1 addition & 1 deletion keras_hub/src/models/resnet/resnet_backbone.py
@@ -177,7 +177,7 @@ def __init__(
use_bias=False,
padding="same",
dtype=dtype,
- name=f"conv{conv_index+1}_conv",
+ name=f"conv{conv_index + 1}_conv",
)(x)

if not use_pre_activation:
10 changes: 5 additions & 5 deletions keras_hub/src/models/retinanet/feature_pyramid.py
@@ -209,9 +209,9 @@ def build(self, input_shapes):
)
if i == backbone_max_level + 1 and self.use_p5:
self.output_conv_layers[level].build(
- (None, None, None, input_shapes[f"P{i-1}"][-1])
+ (None, None, None, input_shapes[f"P{i - 1}"][-1])
if self.data_format == "channels_last"
- else (None, input_shapes[f"P{i-1}"][1], None, None)
+ else (None, input_shapes[f"P{i - 1}"][1], None, None)
)
else:
self.output_conv_layers[level].build(
@@ -277,7 +277,7 @@ def call(self, inputs):
if i < backbone_max_level:
# for the top most output, it doesn't need to merge with any
# upper stream outputs
- upstream_output = self.top_down_op(output_features[f"P{i+1}"])
+ upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
output = self.merge_op([output, upstream_output])
output_features[level] = (
self.lateral_batch_norm_layers[level](output)
@@ -296,9 +296,9 @@ def call(self, inputs):
for i in range(backbone_max_level + 1, self.max_level + 1):
level = f"P{i}"
feats_in = (
- inputs[f"P{i-1}"]
+ inputs[f"P{i - 1}"]
if i == backbone_max_level + 1 and self.use_p5
- else output_features[f"P{i-1}"]
+ else output_features[f"P{i - 1}"]
)
if i > backbone_max_level + 1:
feats_in = self.activation(feats_in)
3 changes: 1 addition & 2 deletions keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -82,8 +82,7 @@ def __init__(

def fit(self, *args, **kwargs):
raise NotImplementedError(
- "Currently, `fit` is not supported for "
- "`StableDiffusion3Inpaint`."
+ "Currently, `fit` is not supported for `StableDiffusion3Inpaint`."
)

def generate_step(
2 changes: 1 addition & 1 deletion keras_hub/src/models/vit/vit_layers.py
@@ -351,7 +351,7 @@ def build(self, input_shape):
attention_dropout=self.attention_dropout,
layer_norm_epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
- name=f"tranformer_block_{i+1}",
+ name=f"tranformer_block_{i + 1}",
)
encoder_block.build((None, None, self.hidden_dim))
self.encoder_layers.append(encoder_block)
3 changes: 1 addition & 2 deletions keras_hub/src/tokenizers/byte_tokenizer.py
@@ -150,8 +150,7 @@ def __init__(
):
if not is_int_dtype(dtype):
raise ValueError(
- "Output dtype must be an integer type. "
- f"Received: dtype={dtype}"
+ f"Output dtype must be an integer type. Received: dtype={dtype}"
)

# Check normalization_form.
3 changes: 1 addition & 2 deletions keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py
@@ -203,8 +203,7 @@ def __init__(
) -> None:
if not is_int_dtype(dtype):
raise ValueError(
- "Output dtype must be an integer type. "
- f"Received: dtype={dtype}"
+ f"Output dtype must be an integer type. Received: dtype={dtype}"
)

# Check normalization_form.
10 changes: 4 additions & 6 deletions keras_hub/src/utils/timm/convert_densenet.py
@@ -59,18 +59,16 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
num_stacks = len(backbone.stackwise_num_repeats)
for stack_index in range(num_stacks):
for block_idx in range(backbone.stackwise_num_repeats[stack_index]):
- keras_name = f"stack{stack_index+1}_block{block_idx+1}"
- hf_name = (
-     f"features.denseblock{stack_index+1}.denselayer{block_idx+1}"
- )
+ keras_name = f"stack{stack_index + 1}_block{block_idx + 1}"
+ hf_name = f"features.denseblock{stack_index + 1}.denselayer{block_idx + 1}"
port_batch_normalization(f"{keras_name}_1_bn", f"{hf_name}.norm1")
port_conv2d(f"{keras_name}_1_conv", f"{hf_name}.conv1")
port_batch_normalization(f"{keras_name}_2_bn", f"{hf_name}.norm2")
port_conv2d(f"{keras_name}_2_conv", f"{hf_name}.conv2")

for stack_index in range(num_stacks - 1):
- keras_transition_name = f"transition{stack_index+1}"
- hf_transition_name = f"features.transition{stack_index+1}"
+ keras_transition_name = f"transition{stack_index + 1}"
+ hf_transition_name = f"features.transition{stack_index + 1}"
port_batch_normalization(
f"{keras_transition_name}_bn", f"{hf_transition_name}.norm"
)
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_efficientnet.py
@@ -268,7 +268,7 @@ def port_batch_normalization(keras_layer, hf_weight_prefix):
# 97 is the start of the lowercase alphabet.
letter_identifier = chr(block_idx + 97)

- keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
+ keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
hf_block_prefix = f"blocks.{stack_index}.{block_idx}."

if block_type == "v1":
2 changes: 1 addition & 1 deletion keras_hub/src/utils/timm/convert_resnet.py
@@ -89,7 +89,7 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
if version == "v1":
keras_name = f"stack{stack_index}_block{block_idx}"
- hf_name = f"layer{stack_index+1}.{block_idx}"
+ hf_name = f"layer{stack_index + 1}.{block_idx}"
else:
keras_name = f"stack{stack_index}_block{block_idx}"
hf_name = f"stages.{stack_index}.blocks.{block_idx}"
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_albert_checkpoints.py
@@ -20,7 +20,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_bart_checkpoints.py
@@ -18,7 +18,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
8 changes: 4 additions & 4 deletions tools/checkpoint_conversion/convert_bloom_checkpoints.py
@@ -46,7 +46,7 @@


flags.DEFINE_string(
- "preset", None, f'Must be one of {", ".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {', '.join(PRESET_MAP.keys())}"
)
flags.mark_flag_as_required("preset")
flags.DEFINE_boolean(
@@ -244,9 +244,9 @@ def preprocessor_call(input_str):

def main(_):
preset = FLAGS.preset
- assert (
-     preset in PRESET_MAP.keys()
- ), f'Invalid preset {preset}. Must be one of {", ".join(PRESET_MAP.keys())}'
+ assert preset in PRESET_MAP.keys(), (
+     f"Invalid preset {preset}. Must be one of {', '.join(PRESET_MAP.keys())}"
+ )

validate_only = FLAGS.validate_only
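
The assert reflow in the hunk above (repeated in the gemma converter below) is the one change that isn't about f-strings: rather than parenthesizing and splitting the condition, the formatter now keeps the condition on the assert line and parenthesizes the long message. A runnable sketch of the before and after, with a stand-in preset map:

# Stand-in map; the real script maps KerasHub preset names to HF handles.
PRESET_MAP = {"bloom_560m_multi": None}
preset = "bloom_560m_multi"

# Old style: the condition was wrapped and the message trailed on the
# closing line:
#     assert (
#         preset in PRESET_MAP.keys()
#     ), f"Invalid preset {preset}. ..."
# New style: the condition stays inline and the message is parenthesized
# so it can wrap within the line limit.
assert preset in PRESET_MAP.keys(), (
    f"Invalid preset {preset}. Must be one of {', '.join(PRESET_MAP.keys())}"
)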
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_clip_checkpoints.py
@@ -65,7 +65,7 @@
flags.DEFINE_string(
"preset",
None,
- f'Must be one of {",".join(PRESET_MAP.keys())}',
+ f"Must be one of {','.join(PRESET_MAP.keys())}",
required=True,
)
flags.DEFINE_string(
@@ -27,7 +27,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
@@ -23,7 +23,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_electra_checkpoints.py
@@ -33,7 +33,7 @@
flags.DEFINE_string(
"preset",
"electra_base_discriminator_en",
- f'Must be one of {",".join(PRESET_MAP)}',
+ f"Must be one of {','.join(PRESET_MAP)}",
)
flags.mark_flag_as_required("preset")
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_f_net_checkpoints.py
@@ -17,7 +17,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_falcon_checkpoints.py
@@ -43,7 +43,7 @@
absl.flags.DEFINE_string(
"preset",
"falcon_refinedweb_1b_en",
- f'Must be one of {",".join(PRESET_MAP.keys())}.',
+ f"Must be one of {','.join(PRESET_MAP.keys())}.",
)
8 changes: 4 additions & 4 deletions tools/checkpoint_conversion/convert_gemma_checkpoints.py
@@ -57,7 +57,7 @@
flags.DEFINE_string(
"preset",
None,
- f'Must be one of {",".join(PRESET_MAP.keys())}',
+ f"Must be one of {','.join(PRESET_MAP.keys())}",
required=True,
)

@@ -228,9 +228,9 @@ def main(_):
flax_dir = FLAGS.flax_dir
else:
presets = PRESET_MAP.keys()
- assert (
-     preset in presets
- ), f'Invalid preset {preset}. Must be one of {",".join(presets)}'
+ assert preset in presets, (
+     f"Invalid preset {preset}. Must be one of {','.join(presets)}"
+ )
handle = PRESET_MAP[preset]
flax_dir = download_flax_model(handle)
6 changes: 3 additions & 3 deletions tools/checkpoint_conversion/convert_gpt2_checkpoints.py
@@ -28,7 +28,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)

@@ -236,8 +236,8 @@ def check_output(

def main(_):
assert FLAGS.preset in PRESET_MAP.keys(), (
- f'Invalid preset {FLAGS.preset}. '
- f'Must be one of {",".join(PRESET_MAP.keys())}'
+ f"Invalid preset {FLAGS.preset}. "
+ f"Must be one of {','.join(PRESET_MAP.keys())}"
)
num_params = PRESET_MAP[FLAGS.preset][0]
hf_model_name = PRESET_MAP[FLAGS.preset][1]
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_llama3_checkpoints.py
@@ -22,7 +22,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_llama_checkpoints.py
@@ -33,7 +33,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)

flags.DEFINE_string(
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_mistral_checkpoints.py
@@ -25,7 +25,7 @@

FLAGS = flags.FLAGS
flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+ "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}"
)
2 changes: 1 addition & 1 deletion tools/checkpoint_conversion/convert_mix_transformer.py
@@ -52,7 +52,7 @@
}

flags.DEFINE_string(
- "preset", None, f'Must be one of {",".join(DOWNLOAD_URLS.keys())}'
+ "preset", None, f"Must be one of {','.join(DOWNLOAD_URLS.keys())}"
)