From 426f62af5eb80e5f0c3b660451dac6f953b4ca0c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 Jun 2020 07:36:58 -0700 Subject: [PATCH] Add an explicit distribution_strategy attribute to TrackableHandler. Since it's treated as a weight by Keras, it needs to pass DistStrat checks during compile. PiperOrigin-RevId: 316678228 Change-Id: I132168f1ca3dd3729d7a499cef3564c5e04abb34 --- tensorflow/python/keras/engine/base_layer_utils.py | 1 + .../layers/preprocessing/index_lookup_distribution_test.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index cb2b9ed87f1..de67080af66 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -774,6 +774,7 @@ class TrackableWeightHandler(object): if not isinstance(trackable, tracking.Trackable): raise ValueError('%s is not a Trackable object.' % (trackable,)) self._trackable = trackable + self._distribute_strategy = distribution_strategy_context.get_strategy() # TODO(b/141682913): Figure out why this is private and fix it. saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py index 098e67f5f6b..c593cd41c85 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py @@ -44,7 +44,7 @@ def get_layer_class(): @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies, - mode=["eager", "graph"])) + mode=["eager"])) # Eager-only, no graph: b/158793009 class IndexLookupDistributionTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): @@ -74,6 +74,7 @@ class IndexLookupDistributionTest( layer.adapt(vocab_dataset) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) + model.compile(loss="mse") output_dataset = model.predict(input_dataset) self.assertAllEqual(expected_output, output_dataset)