Add an explicit distribution_strategy attribute to TrackableHandler. Since it's treated as a weight by Keras, it needs to pass DistStrat checks during compile.
PiperOrigin-RevId: 316678228 Change-Id: I132168f1ca3dd3729d7a499cef3564c5e04abb34
This commit is contained in:
parent
770251a700
commit
426f62af5e
|
@ -774,6 +774,7 @@ class TrackableWeightHandler(object):
|
||||||
if not isinstance(trackable, tracking.Trackable):
|
if not isinstance(trackable, tracking.Trackable):
|
||||||
raise ValueError('%s is not a Trackable object.' % (trackable,))
|
raise ValueError('%s is not a Trackable object.' % (trackable,))
|
||||||
self._trackable = trackable
|
self._trackable = trackable
|
||||||
|
self._distribute_strategy = distribution_strategy_context.get_strategy()
|
||||||
|
|
||||||
# TODO(b/141682913): Figure out why this is private and fix it.
|
# TODO(b/141682913): Figure out why this is private and fix it.
|
||||||
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
|
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
|
||||||
|
|
|
@ -44,7 +44,7 @@ def get_layer_class():
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.combine(
|
combinations.combine(
|
||||||
distribution=strategy_combinations.all_strategies,
|
distribution=strategy_combinations.all_strategies,
|
||||||
mode=["eager", "graph"]))
|
mode=["eager"])) # Eager-only, no graph: b/158793009
|
||||||
class IndexLookupDistributionTest(
|
class IndexLookupDistributionTest(
|
||||||
keras_parameterized.TestCase,
|
keras_parameterized.TestCase,
|
||||||
preprocessing_test_utils.PreprocessingLayerTest):
|
preprocessing_test_utils.PreprocessingLayerTest):
|
||||||
|
@ -74,6 +74,7 @@ class IndexLookupDistributionTest(
|
||||||
layer.adapt(vocab_dataset)
|
layer.adapt(vocab_dataset)
|
||||||
int_data = layer(input_data)
|
int_data = layer(input_data)
|
||||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||||
|
model.compile(loss="mse")
|
||||||
output_dataset = model.predict(input_dataset)
|
output_dataset = model.predict(input_dataset)
|
||||||
self.assertAllEqual(expected_output, output_dataset)
|
self.assertAllEqual(expected_output, output_dataset)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue