Add test to verify that one can use Sequential with a defun on call
.
PiperOrigin-RevId: 207977760
This commit is contained in:
parent
b7172fb01d
commit
96c7397e08
@ -233,6 +233,9 @@ class Sequential(Model):
|
||||
return outputs
|
||||
|
||||
def _call_and_compute_mask(self, inputs, training=None, mask=None):
|
||||
if not self.built:
|
||||
self.build(inputs.shape)
|
||||
|
||||
x = inputs
|
||||
for layer in self.layers:
|
||||
kwargs = {}
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -318,5 +319,28 @@ class TestSequential(test.TestCase, parameterized.TestCase):
|
||||
[v.name for v in model.variables])
|
||||
|
||||
|
||||
class TestSequentialEagerIntegration(test.TestCase):
|
||||
|
||||
@tf_test_util.run_in_graph_and_eager_modes
|
||||
def test_defun_on_call(self):
|
||||
# Check that one can subclass Sequential and place the `call` in a `defun`.
|
||||
|
||||
class MySequential(keras.Sequential):
|
||||
|
||||
def __init__(self, name=None):
|
||||
super(MySequential, self).__init__(name=name)
|
||||
self.call = function.defun(self.call)
|
||||
|
||||
model = MySequential()
|
||||
model.add(keras.layers.Dense(4, activation='relu'))
|
||||
model.add(keras.layers.Dense(5, activation='softmax'))
|
||||
|
||||
model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
|
||||
|
||||
x = np.random.random((2, 6))
|
||||
y = np.random.random((2, 5))
|
||||
model.fit(x, y, epochs=1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user