From b8fc0b69e36ca5bb626fea30e96e7bf8ddbe47ec Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 8 Oct 2020 10:07:01 -0700 Subject: [PATCH] Update bidirectional wrapper tests case that will fail in v1 graph mode in OSS. Also update the enable_output_all_intermediates to take "self" as args, so that it can work with combinations. PiperOrigin-RevId: 336109136 Change-Id: I71bfe7737ca24a118d9d1cab8842e840d520ba8e --- tensorflow/python/framework/test_util.py | 4 ++-- tensorflow/python/keras/layers/wrappers_test.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 7f610393180..b6589bc9bd7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -613,12 +613,12 @@ def enable_output_all_intermediates(fn): The wrapped function """ - def wrapper(*args, **kwargs): + def wrapper(self, *args, **kwargs): output_all_intermediates_old = \ control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True try: - return fn(*args, **kwargs) + return fn(self, *args, **kwargs) finally: control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \ output_all_intermediates_old diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index f1412975cc3..15b672455bc 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -572,6 +573,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase): model.compile(loss='mse', optimizer='sgd') model.fit(x, y, epochs=1, batch_size=1) + @test_util.enable_output_all_intermediates def test_bidirectional_statefulness(self): # Bidirectional and stateful rnn = keras.layers.SimpleRNN