Fix object-based checkpoint dependencies for Keras Wrapper objects.
PiperOrigin-RevId: 201424910
parent cbbffe5f64
commit 1f4a7264c8
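For context, a minimal sketch of the behavior being fixed, using the public object-based checkpointing entry point (tf.train.Checkpoint is my assumption here; the commit itself exercises internal checkpointable utilities). Before the dependency added below, a Wrapper such as TimeDistributed held its inner layer only as a plain Python attribute, so the dependency walk that object-based saving performs never reached the inner layer's variables:

import tensorflow as tf

# Sketch, assuming TF 2.x public APIs: a TimeDistributed wrapper around Dense.
model = tf.keras.Sequential([
    tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(4), input_shape=(3, 5)),
])

# Object-based saving walks named dependencies starting from `model`.
# The fix gives Wrapper an explicit dependency on its inner layer, so the
# wrapped Dense's kernel and bias are reachable and get saved/restored.
ckpt = tf.train.Checkpoint(model=model)
save_path = ckpt.save('/tmp/wrapper_ckpt')
ckpt.restore(save_path)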
@@ -46,6 +46,7 @@ class Wrapper(Layer):

   def __init__(self, layer, **kwargs):
     self.layer = layer
+    self._track_checkpointable(layer, name='layer')
     # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
     # the inner layer has update ops that depend on its inputs (as opposed
     # to the inputs to the Wrapper layer).
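_track_checkpointable is internal API. As a hedged illustration of the dependency edge it creates, the public analogue is attribute assignment on a trackable object such as tf.Module, which auto-tracks attributes (Keras's Layer did not at the time of this commit, hence the explicit call above):

import tensorflow as tf

# Hypothetical wrapper built on tf.Module, to show the same dependency edge.
class MyWrapper(tf.Module):

  def __init__(self, layer):
    super().__init__()
    self.layer = layer  # registers a checkpoint dependency named 'layer'

wrapper = MyWrapper(tf.keras.layers.Dense(2))
wrapper.layer.build((None, 3))  # create the kernel and bias variables
prefix = tf.train.Checkpoint(wrapper=wrapper).save('/tmp/module_ckpt')
# The inner layer's variables appear under the 'wrapper/layer/...' path.
print(tf.train.list_variables(prefix))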
@@ -25,6 +25,7 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_util
 from tensorflow.python.training.rmsprop import RMSPropOptimizer


@@ -85,6 +86,10 @@ class TimeDistributedTest(test.TestCase):
     # test config
     model.get_config()

+    checkpointed_objects = set(checkpointable_util.list_objects(model))
+    for v in model.variables:
+      self.assertIn(v, checkpointed_objects)
+
   def test_timedistributed_static_batch_size(self):
     model = keras.models.Sequential()
     model.add(
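The new test leans on the internal checkpointable_util.list_objects helper; a rough public-API equivalent of the same assertion (my construction, not part of the commit) checks that every model variable actually lands in a saved checkpoint:

import tensorflow as tf

# Sketch, assuming TF 2.x: every variable of the model, including the
# wrapped Dense's kernel and bias, must be serialized to the checkpoint.
model = tf.keras.Sequential([
    tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(2), input_shape=(3, 4)),
])
prefix = tf.train.Checkpoint(model=model).save('/tmp/td_ckpt')
saved = [name for name, _ in tf.train.list_variables(prefix)]
# Each checkpointed variable is stored under a '.../VARIABLE_VALUE' key;
# before this fix the wrapped layer's variables were silently missing.
n_saved = sum(name.endswith('VARIABLE_VALUE') for name in saved)
assert n_saved >= len(model.variables), saved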