Minor logging cleanup, add a few extra tests.
PiperOrigin-RevId: 232395958
This commit is contained in:
parent
038a9fd67b
commit
6f848e3e8b
@ -193,16 +193,17 @@ def converted_call(f, owner, options, *args, **kwargs):
|
|||||||
'Entity {} appears to be decorated by wrapt, which is not yet supported'
|
'Entity {} appears to be decorated by wrapt, which is not yet supported'
|
||||||
' by AutoGraph. The function will be called without transformation.'
|
' by AutoGraph. The function will be called without transformation.'
|
||||||
' You may however apply AutoGraph before the decorator.'.format(f), 1)
|
' You may however apply AutoGraph before the decorator.'.format(f), 1)
|
||||||
|
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
# Other built-in modules are permanently whitelisted.
|
# Other built-in modules are permanently whitelisted.
|
||||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||||
if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
|
if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
|
||||||
f in copy.__dict__.values()):
|
f in copy.__dict__.values()):
|
||||||
|
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
# TODO(mdan): This needs cleanup.
|
# TODO(mdan): This needs cleanup.
|
||||||
# In particular, we may want to avoid renaming functions altogether.
|
|
||||||
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
|
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
|
||||||
|
|
||||||
# TODO(mdan): This may be inconsistent in certain situations.
|
# TODO(mdan): This may be inconsistent in certain situations.
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
@ -26,6 +27,7 @@ import numpy as np
|
|||||||
from tensorflow.python.autograph import utils
|
from tensorflow.python.autograph import utils
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
from tensorflow.python.autograph.impl import api
|
from tensorflow.python.autograph.impl import api
|
||||||
|
from tensorflow.python.autograph.pyct import inspect_utils
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.utils import py_func
|
from tensorflow.python.autograph.utils import py_func
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
@ -46,7 +48,7 @@ class TestResource(str):
|
|||||||
class ApiTest(test.TestCase):
|
class ApiTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_recurses(self):
|
def test_decorator_recursive(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -69,7 +71,7 @@ class ApiTest(test.TestCase):
|
|||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_does_not_recurse(self):
|
def test_decorator_not_recursive(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -90,7 +92,7 @@ class ApiTest(test.TestCase):
|
|||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_unconverted_graph(self):
|
def test_convert_then_do_not_convert_graph(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -105,14 +107,13 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.cached_session() as sess:
|
x = tc.test_method(
|
||||||
x = tc.test_method(
|
constant_op.constant((2, 4)), constant_op.constant(1),
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant(-2))
|
||||||
constant_op.constant(-2))
|
self.assertAllEqual((0, 1), self.evaluate(x))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_unconverted_py_func(self):
|
def test_convert_then_do_not_convert_py_func(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -132,11 +133,10 @@ class ApiTest(test.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
with self.cached_session() as sess:
|
x = tc.test_method(
|
||||||
x = tc.test_method(
|
constant_op.constant((2, 4)), constant_op.constant(1),
|
||||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
constant_op.constant(-2))
|
||||||
constant_op.constant(-2))
|
self.assertAllEqual((0, 1), self.evaluate(x))
|
||||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_decorator_calls_decorated(self):
|
def test_decorator_calls_decorated(self):
|
||||||
@ -265,6 +265,26 @@ class ApiTest(test.TestCase):
|
|||||||
converter.ConversionOptions(), tc)
|
converter.ConversionOptions(), tc)
|
||||||
self.assertEqual(1, self.evaluate(x))
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
|
def test_converted_call_method_converts_recursively(self):
|
||||||
|
|
||||||
|
class TestClass(object):
|
||||||
|
|
||||||
|
def __init__(self, x):
|
||||||
|
self.x = x
|
||||||
|
|
||||||
|
def other_method(self):
|
||||||
|
if self.x < 0:
|
||||||
|
return -self.x
|
||||||
|
return self.x
|
||||||
|
|
||||||
|
def test_method(self):
|
||||||
|
return self.other_method()
|
||||||
|
|
||||||
|
tc = TestClass(constant_op.constant(-1))
|
||||||
|
x = api.converted_call(tc.test_method, None,
|
||||||
|
converter.ConversionOptions(recursive=True), tc)
|
||||||
|
self.assertEqual(1, self.evaluate(x))
|
||||||
|
|
||||||
def test_converted_call_method_by_class(self):
|
def test_converted_call_method_by_class(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -334,6 +354,22 @@ class ApiTest(test.TestCase):
|
|||||||
constant_op.constant(0))
|
constant_op.constant(0))
|
||||||
self.assertTrue(self.evaluate(x))
|
self.assertTrue(self.evaluate(x))
|
||||||
|
|
||||||
|
def test_converted_call_then_already_converted_dynamic(self):
|
||||||
|
|
||||||
|
@api.convert()
|
||||||
|
def g(x):
|
||||||
|
if x > 0:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return -x
|
||||||
|
|
||||||
|
def f(g, x):
|
||||||
|
return g(x)
|
||||||
|
|
||||||
|
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||||
|
g, constant_op.constant(1))
|
||||||
|
self.assertEqual(self.evaluate(x), 1)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_converted_call_no_user_code(self):
|
def test_converted_call_no_user_code(self):
|
||||||
|
|
||||||
@ -397,6 +433,24 @@ class ApiTest(test.TestCase):
|
|||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||||
|
|
||||||
|
def test_converted_call_namedtuple(self):
|
||||||
|
|
||||||
|
opts = converter.ConversionOptions()
|
||||||
|
|
||||||
|
x = api.converted_call(collections.namedtuple, None, opts,
|
||||||
|
'TestNamedtuple', ('a', 'b'))
|
||||||
|
|
||||||
|
self.assertTrue(inspect_utils.isnamedtuple(x))
|
||||||
|
|
||||||
|
def test_converted_call_namedtuple_via_collections(self):
|
||||||
|
|
||||||
|
opts = converter.ConversionOptions()
|
||||||
|
|
||||||
|
x = api.converted_call('namedtuple', collections, opts,
|
||||||
|
'TestNamedtuple', ('a', 'b'))
|
||||||
|
|
||||||
|
self.assertTrue(inspect_utils.isnamedtuple(x))
|
||||||
|
|
||||||
def test_converted_call_lambda(self):
|
def test_converted_call_lambda(self):
|
||||||
|
|
||||||
opts = converter.ConversionOptions()
|
opts = converter.ConversionOptions()
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import imp
|
import imp
|
||||||
# import types
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import gast
|
import gast
|
||||||
@ -87,17 +86,17 @@ def is_whitelisted_for_graph(o):
|
|||||||
# Builtins typically have unnamed modules.
|
# Builtins typically have unnamed modules.
|
||||||
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
||||||
if m.__name__.startswith(prefix):
|
if m.__name__.startswith(prefix):
|
||||||
logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
|
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Temporary -- whitelist tensorboard modules.
|
# Temporary -- whitelist tensorboard modules.
|
||||||
# TODO(b/122731813): Remove.
|
# TODO(b/122731813): Remove.
|
||||||
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
|
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
|
||||||
logging.log(2, '%s is whitelisted: name contains "tensorboard"', o)
|
logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
||||||
logging.log(2, '%s is whitelisted: already converted', o)
|
logging.log(2, 'Whitelisted: %s: already converted', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if hasattr(o, '__call__'):
|
if hasattr(o, '__call__'):
|
||||||
@ -105,9 +104,10 @@ def is_whitelisted_for_graph(o):
|
|||||||
# The type check avoids infinite recursion around the __call__ method
|
# The type check avoids infinite recursion around the __call__ method
|
||||||
# of function objects.
|
# of function objects.
|
||||||
if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
|
if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
|
||||||
logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
|
logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
owner_class = None
|
||||||
if tf_inspect.ismethod(o):
|
if tf_inspect.ismethod(o):
|
||||||
# Methods of whitelisted classes are also whitelisted, even if they are
|
# Methods of whitelisted classes are also whitelisted, even if they are
|
||||||
# bound via user subclasses.
|
# bound via user subclasses.
|
||||||
@ -127,12 +127,12 @@ def is_whitelisted_for_graph(o):
|
|||||||
owner_class = inspect_utils.getmethodclass(o)
|
owner_class = inspect_utils.getmethodclass(o)
|
||||||
if owner_class is not None:
|
if owner_class is not None:
|
||||||
if issubclass(owner_class, unittest.TestCase):
|
if issubclass(owner_class, unittest.TestCase):
|
||||||
logging.log(2, '%s is whitelisted: method of TestCase subclass', o)
|
logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
||||||
if is_whitelisted_for_graph(owner_class):
|
if is_whitelisted_for_graph(owner_class):
|
||||||
logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
|
logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
|
||||||
owner_class)
|
owner_class)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -145,10 +145,10 @@ def is_whitelisted_for_graph(o):
|
|||||||
'Entity {} looks like a namedtuple subclass. Its constructor will'
|
'Entity {} looks like a namedtuple subclass. Its constructor will'
|
||||||
' not be converted by AutoGraph, but if it has any custom methods,'
|
' not be converted by AutoGraph, but if it has any custom methods,'
|
||||||
' those will be.'.format(o), 1)
|
' those will be.'.format(o), 1)
|
||||||
logging.log(2, '%s is whitelisted: named tuple', o)
|
logging.log(2, 'Whitelisted: %s: named tuple', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logging.log(2, '%s is NOT whitelisted', o)
|
logging.log(2, 'Not whitelisted: %s: default rule', o)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user