diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index cb2b9ed87f1..de67080af66 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -774,6 +774,7 @@ class TrackableWeightHandler(object): if not isinstance(trackable, tracking.Trackable): raise ValueError('%s is not a Trackable object.' % (trackable,)) self._trackable = trackable + self._distribute_strategy = distribution_strategy_context.get_strategy() # TODO(b/141682913): Figure out why this is private and fix it. saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py index 098e67f5f6b..c593cd41c85 100644 --- a/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py +++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_distribution_test.py @@ -44,7 +44,7 @@ def get_layer_class(): @combinations.generate( combinations.combine( distribution=strategy_combinations.all_strategies, - mode=["eager", "graph"])) + mode=["eager"])) # Eager-only, no graph: b/158793009 class IndexLookupDistributionTest( keras_parameterized.TestCase, preprocessing_test_utils.PreprocessingLayerTest): @@ -74,6 +74,7 @@ class IndexLookupDistributionTest( layer.adapt(vocab_dataset) int_data = layer(input_data) model = keras.Model(inputs=input_data, outputs=int_data) + model.compile(loss="mse") output_dataset = model.predict(input_dataset) self.assertAllEqual(expected_output, output_dataset)