Decouple unit tests from Keras.
PiperOrigin-RevId: 306032348 Change-Id: Ib938318fc2707e204eab648f912250f4e0c49c8c
This commit is contained in:
parent
65da70ab55
commit
20a26f65d0
@ -41,13 +41,15 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
|
|
||||||
def whitelist(entity):
|
def whitelist(entity):
|
||||||
if 'test_whitelisted_call' not in sys.modules:
|
"""Helper that marks a callable as whtelitisted."""
|
||||||
whitelisted_mod = imp.new_module('test_whitelisted_call')
|
if 'whitelisted_module_for_testing' not in sys.modules:
|
||||||
sys.modules['test_whitelisted_call'] = whitelisted_mod
|
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
||||||
config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) +
|
sys.modules['whitelisted_module_for_testing'] = whitelisted_mod
|
||||||
config.CONVERSION_RULES)
|
config.CONVERSION_RULES = (
|
||||||
|
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
||||||
|
config.CONVERSION_RULES)
|
||||||
|
|
||||||
entity.__module__ = 'test_whitelisted_call'
|
entity.__module__ = 'whitelisted_module_for_testing'
|
||||||
|
|
||||||
|
|
||||||
def is_inside_generated_code():
|
def is_inside_generated_code():
|
||||||
|
@ -48,8 +48,6 @@ from tensorflow.python.eager import def_function
|
|||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.keras.engine import sequential
|
|
||||||
from tensorflow.python.keras.layers import core
|
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -523,7 +521,7 @@ class ApiTest(test.TestCase):
|
|||||||
ag_logging.set_verbosity(0, False)
|
ag_logging.set_verbosity(0, False)
|
||||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||||
|
|
||||||
def test_converted_call_partial_of_whitelisted_method(self):
|
def test_converted_call_partial_of_whitelisted_function(self):
|
||||||
|
|
||||||
def test_fn(_):
|
def test_fn(_):
|
||||||
self.assertFalse(converter_testing.is_inside_generated_code())
|
self.assertFalse(converter_testing.is_inside_generated_code())
|
||||||
@ -610,25 +608,29 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
def test_converted_call_whitelisted_method(self):
|
def test_converted_call_whitelisted_method(self):
|
||||||
|
|
||||||
model = sequential.Sequential([core.Dense(2)])
|
class TestClass(object):
|
||||||
|
|
||||||
x = api.converted_call(
|
def method(self):
|
||||||
model.call, (constant_op.constant([[0.0]]),), {'training': True},
|
return converter_testing.is_inside_generated_code()
|
||||||
options=DEFAULT_RECURSIVE)
|
|
||||||
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
obj = TestClass()
|
||||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
converter_testing.whitelist(obj.method.__func__)
|
||||||
|
|
||||||
|
self.assertFalse(
|
||||||
|
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||||
|
|
||||||
def test_converted_call_whitelisted_method_via_owner(self):
|
def test_converted_call_whitelisted_method_via_owner(self):
|
||||||
|
|
||||||
model = sequential.Sequential([core.Dense(2)])
|
class TestClass(object):
|
||||||
|
|
||||||
x = api.converted_call(
|
def method(self):
|
||||||
model.call, (constant_op.constant([[0.0]]),), {'training': True},
|
return converter_testing.is_inside_generated_code()
|
||||||
options=DEFAULT_RECURSIVE)
|
|
||||||
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
converter_testing.whitelist(TestClass)
|
||||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
|
||||||
|
obj = TestClass()
|
||||||
|
self.assertFalse(
|
||||||
|
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||||
|
|
||||||
def test_converted_call_numpy(self):
|
def test_converted_call_numpy(self):
|
||||||
|
|
||||||
@ -1102,11 +1104,21 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
def test_tf_convert_whitelisted_method(self):
|
def test_tf_convert_whitelisted_method(self):
|
||||||
|
|
||||||
model = sequential.Sequential([core.Dense(2)])
|
if six.PY2:
|
||||||
|
self.skipTest('Test bank not comptible with Python 2.')
|
||||||
|
|
||||||
|
class TestClass(object):
|
||||||
|
|
||||||
|
def method(self):
|
||||||
|
return converter_testing.is_inside_generated_code()
|
||||||
|
|
||||||
|
converter_testing.whitelist(TestClass.method)
|
||||||
|
|
||||||
|
obj = TestClass()
|
||||||
converted_call = api.tf_convert(
|
converted_call = api.tf_convert(
|
||||||
model.call, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
obj.method, ag_ctx.ControlStatusCtx(status=ag_ctx.Status.ENABLED))
|
||||||
_, converted_target = tf_decorator.unwrap(converted_call)
|
_, converted_target = tf_decorator.unwrap(converted_call)
|
||||||
self.assertIs(converted_target.__func__, model.call.__func__)
|
self.assertIs(converted_target.__func__, obj.method.__func__)
|
||||||
|
|
||||||
def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):
|
def test_tf_convert_tf_decorator_unwrapping_context_enabled(self):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user