Add an explicit distribution_strategy attribute to TrackableHandler. Since it's treated as a weight by Keras, it needs to pass DistStrat checks during compile.

PiperOrigin-RevId: 316678228
Change-Id: I132168f1ca3dd3729d7a499cef3564c5e04abb34
This commit is contained in:
A. Unique TensorFlower 2020-06-16 07:36:58 -07:00 committed by TensorFlower Gardener
parent 770251a700
commit 426f62af5e
2 changed files with 3 additions and 1 deletions

View File

@ -774,6 +774,7 @@ class TrackableWeightHandler(object):
if not isinstance(trackable, tracking.Trackable): if not isinstance(trackable, tracking.Trackable):
raise ValueError('%s is not a Trackable object.' % (trackable,)) raise ValueError('%s is not a Trackable object.' % (trackable,))
self._trackable = trackable self._trackable = trackable
self._distribute_strategy = distribution_strategy_context.get_strategy()
# TODO(b/141682913): Figure out why this is private and fix it. # TODO(b/141682913): Figure out why this is private and fix it.
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access

View File

@ -44,7 +44,7 @@ def get_layer_class():
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution=strategy_combinations.all_strategies, distribution=strategy_combinations.all_strategies,
mode=["eager", "graph"])) mode=["eager"])) # Eager-only, no graph: b/158793009
class IndexLookupDistributionTest( class IndexLookupDistributionTest(
keras_parameterized.TestCase, keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest): preprocessing_test_utils.PreprocessingLayerTest):
@ -74,6 +74,7 @@ class IndexLookupDistributionTest(
layer.adapt(vocab_dataset) layer.adapt(vocab_dataset)
int_data = layer(input_data) int_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=int_data) model = keras.Model(inputs=input_data, outputs=int_data)
model.compile(loss="mse")
output_dataset = model.predict(input_dataset) output_dataset = model.predict(input_dataset)
self.assertAllEqual(expected_output, output_dataset) self.assertAllEqual(expected_output, output_dataset)