Update bidirectional wrapper tests case that will fail in v1 graph mode in OSS.

Also update the enable_output_all_intermediates to take "self" as args, so that it can work with combinations.

PiperOrigin-RevId: 336109136
Change-Id: I71bfe7737ca24a118d9d1cab8842e840d520ba8e
This commit is contained in:
Scott Zhu 2020-10-08 10:07:01 -07:00 committed by TensorFlower Gardener
parent 3e0aee0d54
commit b8fc0b69e3
2 changed files with 4 additions and 2 deletions

View File

@ -613,12 +613,12 @@ def enable_output_all_intermediates(fn):
The wrapped function
"""
def wrapper(*args, **kwargs):
def wrapper(self, *args, **kwargs):
output_all_intermediates_old = \
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
try:
return fn(*args, **kwargs)
return fn(self, *args, **kwargs)
finally:
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
output_all_intermediates_old

View File

@ -28,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
@ -572,6 +573,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
@test_util.enable_output_all_intermediates
def test_bidirectional_statefulness(self):
# Bidirectional and stateful
rnn = keras.layers.SimpleRNN