Decouple unit tests from Keras.

PiperOrigin-RevId: 306032348
Change-Id: Ib938318fc2707e204eab648f912250f4e0c49c8c
This commit is contained in:
Dan Moldovan 2020-04-11 06:59:46 -07:00 committed by TensorFlower Gardener
parent 65da70ab55
commit 20a26f65d0
2 changed files with 38 additions and 24 deletions
tensorflow/python/autograph

View File

@ -41,13 +41,15 @@ from tensorflow.python.platform import test
def whitelist(entity):
if 'test_whitelisted_call' not in sys.modules:
whitelisted_mod = imp.new_module('test_whitelisted_call')
sys.modules['test_whitelisted_call'] = whitelisted_mod
config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) +
config.CONVERSION_RULES)
"""Helper that marks a callable as whtelitisted."""
if 'whitelisted_module_for_testing' not in sys.modules:
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
sys.modules['whitelisted_module_for_testing'] = whitelisted_mod
config.CONVERSION_RULES = (
(config.DoNotConvert('whitelisted_module_for_testing'),) +
config.CONVERSION_RULES)
entity.__module__ = 'test_whitelisted_call'
entity.__module__ = 'whitelisted_module_for_testing'
def is_inside_generated_code():

View File

@ -48,8 +48,6 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -523,7 +521,7 @@ class ApiTest(test.TestCase):
ag_logging.set_verbosity(0, False)
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
def test_converted_call_partial_of_whitelisted_method(self):
def test_converted_call_partial_of_whitelisted_function(self):
def test_fn(_):
self.assertFalse(converter_testing.is_inside_generated_code())
@ -610,25 +608,29 @@ class ApiTest(test.TestCase):
def test_converted_call_whitelisted_method(self):
model = sequential.Sequential([core.Dense(2)])
class TestClass(object):
x = api.converted_call(
model.call, (constant_op.constant([[0.0]]),), {'training': True},
options=DEFAULT_RECURSIVE)
def method(self):
return converter_testing.is_inside_generated_code()
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
obj = TestClass()
converter_testing.whitelist(obj.method.__func__)
self.assertFalse(
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
def test_converted_call_whitelisted_method_via_owner(self):
model = sequential.Sequential([core.Dense(2)])
class TestClass(object):
x = api.converted_call(
model.call, (constant_op.constant([[0.0]]),), {'training': True},
options=DEFAULT_RECURSIVE)
def method(self):
return converter_testing.is_inside_generated_code()
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
converter_testing.whitelist(TestClass)
obj = TestClass()
self.assertFalse(
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
def test_converted_call_numpy(self):
@ -1102,11 +1104,21 @@ class ApiTest(test.TestCase):
def test_tf_convert_whitelisted_method(self):
model = sequential.Sequential([core.Dense(2)])
if six.PY2:
self.skipTest('Test bank not comptible with Python 2.')
class TestClass(object):
def method(self):
return converter_testing.is_inside_generated_code()
converter_testing.whitelist(TestClass.method)
obj = TestClass()
converted_call = api.tf_convert(
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
_, converted_target = tf_decorator.unwrap(converted_call)
self.assertIs(converted_target.__func__, model.call.__func__)
self.assertIs(converted_target.__func__, obj.method.__func__)
def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):