Update bidirectional wrapper tests case that will fail in v1 graph mode in OSS.
Also update the enable_output_all_intermediates to take "self" as args, so that it can work with combinations. PiperOrigin-RevId: 336109136 Change-Id: I71bfe7737ca24a118d9d1cab8842e840d520ba8e
This commit is contained in:
parent
3e0aee0d54
commit
b8fc0b69e3
@ -613,12 +613,12 @@ def enable_output_all_intermediates(fn):
|
||||
The wrapped function
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
output_all_intermediates_old = \
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
return fn(self, *args, **kwargs)
|
||||
finally:
|
||||
control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
|
||||
output_all_intermediates_old
|
||||
|
||||
@ -28,6 +28,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
@ -572,6 +573,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
|
||||
model.compile(loss='mse', optimizer='sgd')
|
||||
model.fit(x, y, epochs=1, batch_size=1)
|
||||
|
||||
@test_util.enable_output_all_intermediates
|
||||
def test_bidirectional_statefulness(self):
|
||||
# Bidirectional and stateful
|
||||
rnn = keras.layers.SimpleRNN
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user