Fix object-based checkpoint dependencies for Keras Wrapper objects.

PiperOrigin-RevId: 201424910
This commit is contained in:
Allen Lavoie 2018-06-20 15:11:59 -07:00 committed by TensorFlower Gardener
parent cbbffe5f64
commit 1f4a7264c8
2 changed files with 6 additions and 0 deletions

View File

@ -46,6 +46,7 @@ class Wrapper(Layer):
def __init__(self, layer, **kwargs):
self.layer = layer
self._track_checkpointable(layer, name='layer')
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).

View File

@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@ -85,6 +86,10 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
checkpointed_objects = set(checkpointable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
def test_timedistributed_static_batch_size(self):
model = keras.models.Sequential()
model.add(