Prepare cudnn_recurrent_test for Tensor equality.
PiperOrigin-RevId: 263449689
(cherry picked from commit ff61cee968
)
This commit is contained in:
parent
97071c535c
commit
48fa7d9881
@ -139,7 +139,9 @@ class CuDNNTest(keras_parameterized.TestCase):
|
||||
output = layer(inputs, initial_state=initial_state[0])
|
||||
else:
|
||||
output = layer(inputs, initial_state=initial_state)
|
||||
self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)
|
||||
self.assertTrue(
|
||||
any(initial_state[0] is t
|
||||
for t in layer._inbound_nodes[0].input_tensors))
|
||||
|
||||
model = keras.models.Model([inputs] + initial_state, output)
|
||||
model.compile(
|
||||
|
Loading…
Reference in New Issue
Block a user