Decouple unit tests from Keras.
PiperOrigin-RevId: 306032348 Change-Id: Ib938318fc2707e204eab648f912250f4e0c49c8c
This commit is contained in:
parent
65da70ab55
commit
20a26f65d0
tensorflow/python/autograph
@ -41,13 +41,15 @@ from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def whitelist(entity):
|
||||
if 'test_whitelisted_call' not in sys.modules:
|
||||
whitelisted_mod = imp.new_module('test_whitelisted_call')
|
||||
sys.modules['test_whitelisted_call'] = whitelisted_mod
|
||||
config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) +
|
||||
config.CONVERSION_RULES)
|
||||
"""Helper that marks a callable as whtelitisted."""
|
||||
if 'whitelisted_module_for_testing' not in sys.modules:
|
||||
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
||||
sys.modules['whitelisted_module_for_testing'] = whitelisted_mod
|
||||
config.CONVERSION_RULES = (
|
||||
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
||||
config.CONVERSION_RULES)
|
||||
|
||||
entity.__module__ = 'test_whitelisted_call'
|
||||
entity.__module__ = 'whitelisted_module_for_testing'
|
||||
|
||||
|
||||
def is_inside_generated_code():
|
||||
|
@ -48,8 +48,6 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -523,7 +521,7 @@ class ApiTest(test.TestCase):
|
||||
ag_logging.set_verbosity(0, False)
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||
|
||||
def test_converted_call_partial_of_whitelisted_method(self):
|
||||
def test_converted_call_partial_of_whitelisted_function(self):
|
||||
|
||||
def test_fn(_):
|
||||
self.assertFalse(converter_testing.is_inside_generated_code())
|
||||
@ -610,25 +608,29 @@ class ApiTest(test.TestCase):
|
||||
|
||||
def test_converted_call_whitelisted_method(self):
|
||||
|
||||
model = sequential.Sequential([core.Dense(2)])
|
||||
class TestClass(object):
|
||||
|
||||
x = api.converted_call(
|
||||
model.call, (constant_op.constant([[0.0]]),), {'training': True},
|
||||
options=DEFAULT_RECURSIVE)
|
||||
def method(self):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
obj = TestClass()
|
||||
converter_testing.whitelist(obj.method.__func__)
|
||||
|
||||
self.assertFalse(
|
||||
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||
|
||||
def test_converted_call_whitelisted_method_via_owner(self):
|
||||
|
||||
model = sequential.Sequential([core.Dense(2)])
|
||||
class TestClass(object):
|
||||
|
||||
x = api.converted_call(
|
||||
model.call, (constant_op.constant([[0.0]]),), {'training': True},
|
||||
options=DEFAULT_RECURSIVE)
|
||||
def method(self):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
converter_testing.whitelist(TestClass)
|
||||
|
||||
obj = TestClass()
|
||||
self.assertFalse(
|
||||
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||
|
||||
def test_converted_call_numpy(self):
|
||||
|
||||
@ -1102,11 +1104,21 @@ class ApiTest(test.TestCase):
|
||||
|
||||
def test_tf_convert_whitelisted_method(self):
|
||||
|
||||
model = sequential.Sequential([core.Dense(2)])
|
||||
if six.PY2:
|
||||
self.skipTest('Test bank not comptible with Python 2.')
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def method(self):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
converter_testing.whitelist(TestClass.method)
|
||||
|
||||
obj = TestClass()
|
||||
converted_call = api.tf_convert(
|
||||
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
||||
obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
||||
_, converted_target = tf_decorator.unwrap(converted_call)
|
||||
self.assertIs(converted_target.__func__, model.call.__func__)
|
||||
self.assertIs(converted_target.__func__, obj.method.__func__)
|
||||
|
||||
def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user