diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 1b37fb4131c..b9e72e66c2e 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -41,13 +41,15 @@ from tensorflow.python.platform import test def whitelist(entity): - if 'test_whitelisted_call' not in sys.modules: - whitelisted_mod = imp.new_module('test_whitelisted_call') - sys.modules['test_whitelisted_call'] = whitelisted_mod - config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) + - config.CONVERSION_RULES) + """Helper that marks a callable as whtelitisted.""" + if 'whitelisted_module_for_testing' not in sys.modules: + whitelisted_mod = imp.new_module('whitelisted_module_for_testing') + sys.modules['whitelisted_module_for_testing'] = whitelisted_mod + config.CONVERSION_RULES = ( + (config.DoNotConvert('whitelisted_module_for_testing'),) + + config.CONVERSION_RULES) - entity.__module__ = 'test_whitelisted_call' + entity.__module__ = 'whitelisted_module_for_testing' def is_inside_generated_code(): diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index d8f73f20674..146cca2f2eb 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -48,8 +48,6 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util -from tensorflow.python.keras.engine import sequential -from tensorflow.python.keras.layers import core from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -523,7 +521,7 @@ class ApiTest(test.TestCase): ag_logging.set_verbosity(0, False) os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' - def test_converted_call_partial_of_whitelisted_method(self): + def test_converted_call_partial_of_whitelisted_function(self): def test_fn(_): self.assertFalse(converter_testing.is_inside_generated_code()) @@ -610,25 +608,29 @@ class ApiTest(test.TestCase): def test_converted_call_whitelisted_method(self): - model = sequential.Sequential([core.Dense(2)]) + class TestClass(object): - x = api.converted_call( - model.call, (constant_op.constant([[0.0]]),), {'training': True}, - options=DEFAULT_RECURSIVE) + def method(self): + return converter_testing.is_inside_generated_code() - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual([[0.0, 0.0]], self.evaluate(x)) + obj = TestClass() + converter_testing.whitelist(obj.method.__func__) + + self.assertFalse( + api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE)) def test_converted_call_whitelisted_method_via_owner(self): - model = sequential.Sequential([core.Dense(2)]) + class TestClass(object): - x = api.converted_call( - model.call, (constant_op.constant([[0.0]]),), {'training': True}, - options=DEFAULT_RECURSIVE) + def method(self): + return converter_testing.is_inside_generated_code() - self.evaluate(variables.global_variables_initializer()) - self.assertAllEqual([[0.0, 0.0]], self.evaluate(x)) + converter_testing.whitelist(TestClass) + + obj = TestClass() + self.assertFalse( + api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE)) def test_converted_call_numpy(self): @@ -1102,11 +1104,21 @@ class ApiTest(test.TestCase): def test_tf_convert_whitelisted_method(self): - model = sequential.Sequential([core.Dense(2)]) + if six.PY2: + self.skipTest('Test bank not comptible with Python 2.') + + class TestClass(object): + + def method(self): + return converter_testing.is_inside_generated_code() + + converter_testing.whitelist(TestClass.method) + + obj = TestClass() converted_call = api.tf_convert( - model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)) + obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED)) _, converted_target = tf_decorator.unwrap(converted_call) - self.assertIs(converted_target.__func__, model.call.__func__) + self.assertIs(converted_target.__func__, obj.method.__func__) def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):