Add test to verify that one can use Sequential with a defun on call.

PiperOrigin-RevId: 207977760
This commit is contained in:
Francois Chollet 2018-08-08 18:05:46 -07:00 committed by TensorFlower Gardener
parent b7172fb01d
commit 96c7397e08
2 changed files with 27 additions and 0 deletions

View File

@ -233,6 +233,9 @@ class Sequential(Model):
return outputs
def _call_and_compute_mask(self, inputs, training=None, mask=None):
if not self.built:
self.build(inputs.shape)
x = inputs
for layer in self.layers:
kwargs = {}

View File

@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -318,5 +319,28 @@ class TestSequential(test.TestCase, parameterized.TestCase):
[v.name for v in model.variables])
class TestSequentialEagerIntegration(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_defun_on_call(self):
# Check that one can subclass Sequential and place the `call` in a `defun`.
class MySequential(keras.Sequential):
def __init__(self, name=None):
super(MySequential, self).__init__(name=name)
self.call = function.defun(self.call)
model = MySequential()
model.add(keras.layers.Dense(4, activation='relu'))
model.add(keras.layers.Dense(5, activation='softmax'))
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
x = np.random.random((2, 6))
y = np.random.random((2, 5))
model.fit(x, y, epochs=1)
if __name__ == '__main__':
test.main()