Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding RegNets to tf.keras.applications #15702

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions keras/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ keras_packages = [
"keras.applications.mobilenet_v2",
"keras.applications.mobilenet_v3",
"keras.applications.nasnet",
"keras.applications.regnet",
"keras.applications.resnet",
"keras.applications.resnet_v2",
"keras.applications.vgg16",
Expand Down
2 changes: 2 additions & 0 deletions keras/api/api_init_files.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ KERAS_API_INIT_FILES = [
"keras/applications/mobilenet_v2/__init__.py",
"keras/applications/mobilenet_v3/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/regnet/__init__.py",
"keras/applications/resnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/resnet_v2/__init__.py",
Expand Down Expand Up @@ -85,6 +86,7 @@ KERAS_API_INIT_FILES_V1 = [
"keras/applications/mobilenet_v2/__init__.py",
"keras/applications/mobilenet_v3/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/regnet/__init__.py",
"keras/applications/resnet/__init__.py",
"keras/applications/resnet_v2/__init__.py",
"keras/applications/resnet50/__init__.py",
Expand Down
18 changes: 18 additions & 0 deletions keras/applications/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ py_library(
"mobilenet_v2.py",
"mobilenet_v3.py",
"nasnet.py",
"regnet.py",
"resnet.py",
"resnet_v2.py",
"vgg16.py",
Expand Down Expand Up @@ -312,6 +313,23 @@ tf_py_test(
],
)

tf_py_test(
name = "applications_load_weight_test_regnet",
srcs = ["applications_load_weight_test.py"],
args = ["--module=regnet"],
main = "applications_load_weight_test.py",
tags = [
"no_oss",
"no_pip",
],
deps = [
":applications",
"//:expect_absl_installed",
"//:expect_tensorflow_installed",
"//keras/preprocessing",
],
)

tf_py_test(
name = "applications_load_weight_test_nasnet_mobile",
srcs = ["applications_load_weight_test.py"],
Expand Down
14 changes: 12 additions & 2 deletions keras/applications/applications_load_weight_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from keras.applications import mobilenet_v2
from keras.applications import mobilenet_v3
from keras.applications import nasnet
from keras.applications import regnet
from keras.applications import resnet
from keras.applications import resnet_v2
from keras.applications import vgg16
Expand Down Expand Up @@ -69,7 +70,16 @@
efficientnet_v2.EfficientNetV2B2, efficientnet_v2.EfficientNetV2B3,
efficientnet_v2.EfficientNetV2S, efficientnet_v2.EfficientNetV2M,
efficientnet_v2.EfficientNetV2L
])
]),
'regnet': (regnet,
AdityaKane2001 marked this conversation as resolved.
Show resolved Hide resolved
[regnet.RegNetX002, regnet.RegNetX004, regnet.RegNetX006,
regnet.RegNetX008, regnet.RegNetX016, regnet.RegNetX032,
regnet.RegNetX040, regnet.RegNetX064, regnet.RegNetX080,
regnet.RegNetX120, regnet.RegNetX160, regnet.RegNetX320,
regnet.RegNetY002, regnet.RegNetY004, regnet.RegNetY006,
regnet.RegNetY008, regnet.RegNetY016, regnet.RegNetY032,
regnet.RegNetY040, regnet.RegNetY064, regnet.RegNetY080,
regnet.RegNetY120, regnet.RegNetY160, regnet.RegNetY320])
}

TEST_IMAGE_PATH = ('https://storage.googleapis.com/tensorflow/'
Expand Down Expand Up @@ -115,7 +125,7 @@ def test_application_pretrained_weights_loading(self):
self.assertShapeEqual(model.output_shape, (None, _IMAGENET_CLASSES))
x = _get_elephant(model.input_shape[1:3])
x = app_module.preprocess_input(x)
preds = model.predict(x)
preds = model(x).numpy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this needs to be updated? Do things break otherwise? We still run these test cases in TF1 without eager mode enabled, and this line is breaking a number of our tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattdangerw

Yes, the change is necessary. Grouped convolutions are not yet fully supported on CPUs. We see that model.predict(X_test) breaks whereas model(X_test) works fine.

There are a number of issues discussing this in the TF repo.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could trigger the failures by adding a call to tf.compat.v1.disable_v2_behavior() after before the call to tf.test.main in applications_load_weight_test. We can't submit this if we are breaking all these application tests in a TF1 context. We would need to find a change that does not rely on eager mode behavior (.numpy is eager only).

This might mean we need to dig into the difference between direct call vs predict here. It sound like this is an issue with grouped convolutions on CPU that will only appear when compiling a tf.function, is that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattdangerw

I have found a small solution to this. I have tested the change both with TF2 and using tf.compat.v1.disable_v2_behavior() and it works on my end. Could you please take a look and run the workflow again?

x = app_module.preprocess_input(x)
try:
preds = model.predict(x) # Works in TF1
except:
preds = model(x).numpy() # Works in TF2
names = [p[1] for p in app_module.decode_predictions(preds)[0]]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will take a look next week! Sorry for the delayed reply, most of the team is out this week. The proposed change would still not submit, because that fallback (the numpy call) would still be run in a TF1 context for the regnet load weights test unless we disable it.

Overall, I think our options are...

  1. Disable the load weights test for regnet (without removing the predict call here), and follow up with a fix.
  2. Fix the underlying CPU/compiled function/grouped convolution issue, and then land this PR.
  3. Work around the bug for regnets somehow (the conversation here suggests that using jit_compile=True may allow CPU to work, which might give us a way forward).

I would say 3) would be the way to go if we can make it work. We really do want the load weights tests to test the compile predicted function (that's how these will often be used!), and shipping regnets such that predict will fail on CPU by default is not a great out of box experience.

Will follow up next week when people are back in office. Thanks!

Copy link
Contributor Author

@AdityaKane2001 AdityaKane2001 Dec 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattdangerw

Thanks for the detailed explanation.

I guess (2) is not something which can be done in the Keras codebase, as the error is thrown in tensorflow/tensorflow/core/kernels/conv_ops_fused_impl.h. I'll open an issue in the TF repo regarding this. So I agree that (3) might be the best option.

Lastly, wish you very happy new year!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattdangerw

Could you please take a look at this one? TIA

/cc @fchollet @qlzh727

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still here on this! We think we have found a good workaround (option 3), forcing XLA compilation grouped convolutions. #15868

Once that lands (assuming that doesn't run into road blocks), we can submit this without modifying the predict call in load weights tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mattdangerw

Thanks a lot for this! Really appreciate it.

names = [p[1] for p in app_module.decode_predictions(preds)[0]]
# Test correct label is in top 3 (weak correctness test).
self.assertIn('African_elephant', names[:3])
Expand Down
25 changes: 25 additions & 0 deletions keras/applications/applications_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from keras.applications import mobilenet_v2
from keras.applications import mobilenet_v3
from keras.applications import nasnet
from keras.applications import regnet
from keras.applications import resnet
from keras.applications import resnet_v2
from keras.applications import vgg16
Expand Down Expand Up @@ -69,6 +70,30 @@
(efficientnet_v2.EfficientNetV2S, 1280),
(efficientnet_v2.EfficientNetV2M, 1280),
(efficientnet_v2.EfficientNetV2L, 1280),
(regnet.RegNetX002, 368),
(regnet.RegNetX004, 384),
(regnet.RegNetX006, 528),
(regnet.RegNetX008, 672),
(regnet.RegNetX016, 912),
(regnet.RegNetX032, 1008),
(regnet.RegNetX040, 1360),
(regnet.RegNetX064, 1624),
(regnet.RegNetX080, 1920),
(regnet.RegNetX120, 2240),
(regnet.RegNetX160, 2048),
(regnet.RegNetX320, 2520),
(regnet.RegNetY002, 368),
(regnet.RegNetY004, 440),
(regnet.RegNetY006, 608),
(regnet.RegNetY008, 768),
(regnet.RegNetY016, 888),
(regnet.RegNetY032, 1512),
(regnet.RegNetY040, 1088),
(regnet.RegNetY064, 1296),
(regnet.RegNetY080, 2016),
(regnet.RegNetY120, 2240),
(regnet.RegNetY160, 3024),
(regnet.RegNetY320, 3712)
]

NASNET_LIST = [
Expand Down
Loading