Minor logging cleanup, add a few extra tests.
PiperOrigin-RevId: 232395958
This commit is contained in:
parent
038a9fd67b
commit
6f848e3e8b
@ -193,16 +193,17 @@ def converted_call(f, owner, options, *args, **kwargs):
|
||||
'Entity {} appears to be decorated by wrapt, which is not yet supported'
|
||||
' by AutoGraph. The function will be called without transformation.'
|
||||
' You may however apply AutoGraph before the decorator.'.format(f), 1)
|
||||
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# Other built-in modules are permanently whitelisted.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
|
||||
f in copy.__dict__.values()):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||
return f(*args, **kwargs)
|
||||
|
||||
# TODO(mdan): This needs cleanup.
|
||||
# In particular, we may want to avoid renaming functions altogether.
|
||||
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
|
||||
|
||||
# TODO(mdan): This may be inconsistent in certain situations.
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import gc
|
||||
|
||||
@ -26,6 +27,7 @@ import numpy as np
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.impl import api
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.utils import py_func
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -46,7 +48,7 @@ class TestResource(str):
|
||||
class ApiTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_decorator_recurses(self):
|
||||
def test_decorator_recursive(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -69,7 +71,7 @@ class ApiTest(test.TestCase):
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_decorator_does_not_recurse(self):
|
||||
def test_decorator_not_recursive(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -90,7 +92,7 @@ class ApiTest(test.TestCase):
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_decorator_calls_unconverted_graph(self):
|
||||
def test_convert_then_do_not_convert_graph(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -105,14 +107,13 @@ class ApiTest(test.TestCase):
|
||||
return x
|
||||
|
||||
tc = TestClass()
|
||||
with self.cached_session() as sess:
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
x = tc.test_method(
|
||||
constant_op.constant((2, 4)), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertAllEqual((0, 1), self.evaluate(x))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_decorator_calls_unconverted_py_func(self):
|
||||
def test_convert_then_do_not_convert_py_func(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -132,11 +133,10 @@ class ApiTest(test.TestCase):
|
||||
return x
|
||||
|
||||
tc = TestClass()
|
||||
with self.cached_session() as sess:
|
||||
x = tc.test_method(
|
||||
constant_op.constant([2, 4]), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertListEqual([0, 1], self.evaluate(x).tolist())
|
||||
x = tc.test_method(
|
||||
constant_op.constant((2, 4)), constant_op.constant(1),
|
||||
constant_op.constant(-2))
|
||||
self.assertAllEqual((0, 1), self.evaluate(x))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_decorator_calls_decorated(self):
|
||||
@ -265,6 +265,26 @@ class ApiTest(test.TestCase):
|
||||
converter.ConversionOptions(), tc)
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_converts_recursively(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def other_method(self):
|
||||
if self.x < 0:
|
||||
return -self.x
|
||||
return self.x
|
||||
|
||||
def test_method(self):
|
||||
return self.other_method()
|
||||
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
x = api.converted_call(tc.test_method, None,
|
||||
converter.ConversionOptions(recursive=True), tc)
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_by_class(self):
|
||||
|
||||
class TestClass(object):
|
||||
@ -334,6 +354,22 @@ class ApiTest(test.TestCase):
|
||||
constant_op.constant(0))
|
||||
self.assertTrue(self.evaluate(x))
|
||||
|
||||
def test_converted_call_then_already_converted_dynamic(self):
|
||||
|
||||
@api.convert()
|
||||
def g(x):
|
||||
if x > 0:
|
||||
return x
|
||||
else:
|
||||
return -x
|
||||
|
||||
def f(g, x):
|
||||
return g(x)
|
||||
|
||||
x = api.converted_call(f, None, converter.ConversionOptions(),
|
||||
g, constant_op.constant(1))
|
||||
self.assertEqual(self.evaluate(x), 1)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_converted_call_no_user_code(self):
|
||||
|
||||
@ -397,6 +433,24 @@ class ApiTest(test.TestCase):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_namedtuple(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
|
||||
x = api.converted_call(collections.namedtuple, None, opts,
|
||||
'TestNamedtuple', ('a', 'b'))
|
||||
|
||||
self.assertTrue(inspect_utils.isnamedtuple(x))
|
||||
|
||||
def test_converted_call_namedtuple_via_collections(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
|
||||
x = api.converted_call('namedtuple', collections, opts,
|
||||
'TestNamedtuple', ('a', 'b'))
|
||||
|
||||
self.assertTrue(inspect_utils.isnamedtuple(x))
|
||||
|
||||
def test_converted_call_lambda(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import imp
|
||||
# import types
|
||||
import unittest
|
||||
|
||||
import gast
|
||||
@ -87,17 +86,17 @@ def is_whitelisted_for_graph(o):
|
||||
# Builtins typically have unnamed modules.
|
||||
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
|
||||
if m.__name__.startswith(prefix):
|
||||
logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
|
||||
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
|
||||
return True
|
||||
|
||||
# Temporary -- whitelist tensorboard modules.
|
||||
# TODO(b/122731813): Remove.
|
||||
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
|
||||
logging.log(2, '%s is whitelisted: name contains "tensorboard"', o)
|
||||
logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
|
||||
return True
|
||||
|
||||
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
|
||||
logging.log(2, '%s is whitelisted: already converted', o)
|
||||
logging.log(2, 'Whitelisted: %s: already converted', o)
|
||||
return True
|
||||
|
||||
if hasattr(o, '__call__'):
|
||||
@ -105,9 +104,10 @@ def is_whitelisted_for_graph(o):
|
||||
# The type check avoids infinite recursion around the __call__ method
|
||||
# of function objects.
|
||||
if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
|
||||
logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
|
||||
logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
|
||||
return True
|
||||
|
||||
owner_class = None
|
||||
if tf_inspect.ismethod(o):
|
||||
# Methods of whitelisted classes are also whitelisted, even if they are
|
||||
# bound via user subclasses.
|
||||
@ -127,12 +127,12 @@ def is_whitelisted_for_graph(o):
|
||||
owner_class = inspect_utils.getmethodclass(o)
|
||||
if owner_class is not None:
|
||||
if issubclass(owner_class, unittest.TestCase):
|
||||
logging.log(2, '%s is whitelisted: method of TestCase subclass', o)
|
||||
logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
|
||||
return True
|
||||
|
||||
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
||||
if is_whitelisted_for_graph(owner_class):
|
||||
logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
|
||||
logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
|
||||
owner_class)
|
||||
return True
|
||||
|
||||
@ -145,10 +145,10 @@ def is_whitelisted_for_graph(o):
|
||||
'Entity {} looks like a namedtuple subclass. Its constructor will'
|
||||
' not be converted by AutoGraph, but if it has any custom methods,'
|
||||
' those will be.'.format(o), 1)
|
||||
logging.log(2, '%s is whitelisted: named tuple', o)
|
||||
logging.log(2, 'Whitelisted: %s: named tuple', o)
|
||||
return True
|
||||
|
||||
logging.log(2, '%s is NOT whitelisted', o)
|
||||
logging.log(2, 'Not whitelisted: %s: default rule', o)
|
||||
return False
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user