Add an explicit distribution_strategy attribute to TrackableHandler. Since it's treated as a weight by Keras, it needs to pass DistStrat checks during compile.
PiperOrigin-RevId: 316678228 Change-Id: I132168f1ca3dd3729d7a499cef3564c5e04abb34
This commit is contained in:
parent
770251a700
commit
426f62af5e
|
@ -774,6 +774,7 @@ class TrackableWeightHandler(object):
|
|||
if not isinstance(trackable, tracking.Trackable):
|
||||
raise ValueError('%s is not a Trackable object.' % (trackable,))
|
||||
self._trackable = trackable
|
||||
self._distribute_strategy = distribution_strategy_context.get_strategy()
|
||||
|
||||
# TODO(b/141682913): Figure out why this is private and fix it.
|
||||
saveables = trackable._gather_saveables_for_checkpoint().values() # pylint: disable=protected-access
|
||||
|
|
|
@ -44,7 +44,7 @@ def get_layer_class():
|
|||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies,
|
||||
mode=["eager", "graph"]))
|
||||
mode=["eager"])) # Eager-only, no graph: b/158793009
|
||||
class IndexLookupDistributionTest(
|
||||
keras_parameterized.TestCase,
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
@ -74,6 +74,7 @@ class IndexLookupDistributionTest(
|
|||
layer.adapt(vocab_dataset)
|
||||
int_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
model.compile(loss="mse")
|
||||
output_dataset = model.predict(input_dataset)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
|
|
Loading…
Reference in New Issue