Minor logging cleanup, add a few extra tests.

PiperOrigin-RevId: 232395958
This commit is contained in:
Dan Moldovan 2019-02-04 17:12:27 -08:00 committed by TensorFlower Gardener
parent 038a9fd67b
commit 6f848e3e8b
3 changed files with 79 additions and 24 deletions

View File

@ -193,16 +193,17 @@ def converted_call(f, owner, options, *args, **kwargs):
'Entity {} appears to be decorated by wrapt, which is not yet supported'
' by AutoGraph. The function will be called without transformation.'
' You may however apply AutoGraph before the decorator.'.format(f), 1)
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
return f(*args, **kwargs)
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
f in copy.__dict__.values()):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return f(*args, **kwargs)
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
# TODO(mdan): This may be inconsistent in certain situations.

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import gc
@ -26,6 +27,7 @@ import numpy as np
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
@ -46,7 +48,7 @@ class TestResource(str):
class ApiTest(test.TestCase):
@test_util.run_deprecated_v1
def test_decorator_recurses(self):
def test_decorator_recursive(self):
class TestClass(object):
@ -69,7 +71,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_decorator_does_not_recurse(self):
def test_decorator_not_recursive(self):
class TestClass(object):
@ -90,7 +92,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_decorator_calls_unconverted_graph(self):
def test_convert_then_do_not_convert_graph(self):
class TestClass(object):
@ -105,14 +107,13 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_unconverted_py_func(self):
def test_convert_then_do_not_convert_py_func(self):
class TestClass(object):
@ -132,11 +133,10 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_decorated(self):
@ -265,6 +265,26 @@ class ApiTest(test.TestCase):
converter.ConversionOptions(), tc)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_converts_recursively(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def other_method(self):
if self.x < 0:
return -self.x
return self.x
def test_method(self):
return self.other_method()
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, None,
converter.ConversionOptions(recursive=True), tc)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_by_class(self):
class TestClass(object):
@ -334,6 +354,22 @@ class ApiTest(test.TestCase):
constant_op.constant(0))
self.assertTrue(self.evaluate(x))
def test_converted_call_then_already_converted_dynamic(self):
@api.convert()
def g(x):
if x > 0:
return x
else:
return -x
def f(g, x):
return g(x)
x = api.converted_call(f, None, converter.ConversionOptions(),
g, constant_op.constant(1))
self.assertEqual(self.evaluate(x), 1)
@test_util.run_deprecated_v1
def test_converted_call_no_user_code(self):
@ -397,6 +433,24 @@ class ApiTest(test.TestCase):
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_namedtuple(self):
opts = converter.ConversionOptions()
x = api.converted_call(collections.namedtuple, None, opts,
'TestNamedtuple', ('a', 'b'))
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_namedtuple_via_collections(self):
opts = converter.ConversionOptions()
x = api.converted_call('namedtuple', collections, opts,
'TestNamedtuple', ('a', 'b'))
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_lambda(self):
opts = converter.ConversionOptions()

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import functools
import imp
# import types
import unittest
import gast
@ -87,17 +86,17 @@ def is_whitelisted_for_graph(o):
# Builtins typically have unnamed modules.
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
if m.__name__.startswith(prefix):
logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
return True
# Temporary -- whitelist tensorboard modules.
# TODO(b/122731813): Remove.
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
logging.log(2, '%s is whitelisted: name contains "tensorboard"', o)
logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
return True
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
logging.log(2, '%s is whitelisted: already converted', o)
logging.log(2, 'Whitelisted: %s: already converted', o)
return True
if hasattr(o, '__call__'):
@ -105,9 +104,10 @@ def is_whitelisted_for_graph(o):
# The type check avoids infinite recursion around the __call__ method
# of function objects.
if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
return True
owner_class = None
if tf_inspect.ismethod(o):
# Methods of whitelisted classes are also whitelisted, even if they are
# bound via user subclasses.
@ -127,12 +127,12 @@ def is_whitelisted_for_graph(o):
owner_class = inspect_utils.getmethodclass(o)
if owner_class is not None:
if issubclass(owner_class, unittest.TestCase):
logging.log(2, '%s is whitelisted: method of TestCase subclass', o)
logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
return True
owner_class = inspect_utils.getdefiningclass(o, owner_class)
if is_whitelisted_for_graph(owner_class):
logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
owner_class)
return True
@ -145,10 +145,10 @@ def is_whitelisted_for_graph(o):
'Entity {} looks like a namedtuple subclass. Its constructor will'
' not be converted by AutoGraph, but if it has any custom methods,'
' those will be.'.format(o), 1)
logging.log(2, '%s is whitelisted: named tuple', o)
logging.log(2, 'Whitelisted: %s: named tuple', o)
return True
logging.log(2, '%s is NOT whitelisted', o)
logging.log(2, 'Not whitelisted: %s: default rule', o)
return False