Prepare cudnn_recurrent_test for Tensor equality.

PiperOrigin-RevId: 263449689
(cherry picked from commit ff61cee968)
This commit is contained in:
Saurabh Saxena 2019-08-14 16:03:32 -07:00 committed by Gaurav Jain
parent 97071c535c
commit 48fa7d9881

View File

@ -139,7 +139,9 @@ class CuDNNTest(keras_parameterized.TestCase):
output = layer(inputs, initial_state=initial_state[0])
else:
output = layer(inputs, initial_state=initial_state)
self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)
self.assertTrue(
any(initial_state[0] is t
for t in layer._inbound_nodes[0].input_tensors))
model = keras.models.Model([inputs] + initial_state, output)
model.compile(