Minor logging cleanup, add a few extra tests.

PiperOrigin-RevId: 232395958
This commit is contained in:
Dan Moldovan 2019-02-04 17:12:27 -08:00 committed by TensorFlower Gardener
parent 038a9fd67b
commit 6f848e3e8b
3 changed files with 79 additions and 24 deletions

View File

@ -193,16 +193,17 @@ def converted_call(f, owner, options, *args, **kwargs):
'Entity {} appears to be decorated by wrapt, which is not yet supported' 'Entity {} appears to be decorated by wrapt, which is not yet supported'
' by AutoGraph. The function will be called without transformation.' ' by AutoGraph. The function will be called without transformation.'
' You may however apply AutoGraph before the decorator.'.format(f), 1) ' You may however apply AutoGraph before the decorator.'.format(f), 1)
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
return f(*args, **kwargs) return f(*args, **kwargs)
# Other built-in modules are permanently whitelisted. # Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules. # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if (f in collections.__dict__.values() or f in pdb.__dict__.values() or if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
f in copy.__dict__.values()): f in copy.__dict__.values()):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return f(*args, **kwargs) return f(*args, **kwargs)
# TODO(mdan): This needs cleanup. # TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
if not options.force_conversion and conversion.is_whitelisted_for_graph(f): if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
# TODO(mdan): This may be inconsistent in certain situations. # TODO(mdan): This may be inconsistent in certain situations.

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import functools import functools
import gc import gc
@ -26,6 +27,7 @@ import numpy as np
from tensorflow.python.autograph import utils from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import py_func from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
@ -46,7 +48,7 @@ class TestResource(str):
class ApiTest(test.TestCase): class ApiTest(test.TestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_decorator_recurses(self): def test_decorator_recursive(self):
class TestClass(object): class TestClass(object):
@ -69,7 +71,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], self.evaluate(x).tolist()) self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_decorator_does_not_recurse(self): def test_decorator_not_recursive(self):
class TestClass(object): class TestClass(object):
@ -90,7 +92,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], self.evaluate(x).tolist()) self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_decorator_calls_unconverted_graph(self): def test_convert_then_do_not_convert_graph(self):
class TestClass(object): class TestClass(object):
@ -105,14 +107,13 @@ class ApiTest(test.TestCase):
return x return x
tc = TestClass() tc = TestClass()
with self.cached_session() as sess: x = tc.test_method(
x = tc.test_method( constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant([2, 4]), constant_op.constant(1), constant_op.constant(-2))
constant_op.constant(-2)) self.assertAllEqual((0, 1), self.evaluate(x))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_decorator_calls_unconverted_py_func(self): def test_convert_then_do_not_convert_py_func(self):
class TestClass(object): class TestClass(object):
@ -132,11 +133,10 @@ class ApiTest(test.TestCase):
return x return x
tc = TestClass() tc = TestClass()
with self.cached_session() as sess: x = tc.test_method(
x = tc.test_method( constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant([2, 4]), constant_op.constant(1), constant_op.constant(-2))
constant_op.constant(-2)) self.assertAllEqual((0, 1), self.evaluate(x))
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_decorator_calls_decorated(self): def test_decorator_calls_decorated(self):
@ -265,6 +265,26 @@ class ApiTest(test.TestCase):
converter.ConversionOptions(), tc) converter.ConversionOptions(), tc)
self.assertEqual(1, self.evaluate(x)) self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_converts_recursively(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def other_method(self):
if self.x < 0:
return -self.x
return self.x
def test_method(self):
return self.other_method()
tc = TestClass(constant_op.constant(-1))
x = api.converted_call(tc.test_method, None,
converter.ConversionOptions(recursive=True), tc)
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_by_class(self): def test_converted_call_method_by_class(self):
class TestClass(object): class TestClass(object):
@ -334,6 +354,22 @@ class ApiTest(test.TestCase):
constant_op.constant(0)) constant_op.constant(0))
self.assertTrue(self.evaluate(x)) self.assertTrue(self.evaluate(x))
def test_converted_call_then_already_converted_dynamic(self):
@api.convert()
def g(x):
if x > 0:
return x
else:
return -x
def f(g, x):
return g(x)
x = api.converted_call(f, None, converter.ConversionOptions(),
g, constant_op.constant(1))
self.assertEqual(self.evaluate(x), 1)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def test_converted_call_no_user_code(self): def test_converted_call_no_user_code(self):
@ -397,6 +433,24 @@ class ApiTest(test.TestCase):
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x)) self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_namedtuple(self):
opts = converter.ConversionOptions()
x = api.converted_call(collections.namedtuple, None, opts,
'TestNamedtuple', ('a', 'b'))
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_namedtuple_via_collections(self):
opts = converter.ConversionOptions()
x = api.converted_call('namedtuple', collections, opts,
'TestNamedtuple', ('a', 'b'))
self.assertTrue(inspect_utils.isnamedtuple(x))
def test_converted_call_lambda(self): def test_converted_call_lambda(self):
opts = converter.ConversionOptions() opts = converter.ConversionOptions()

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import functools import functools
import imp import imp
# import types
import unittest import unittest
import gast import gast
@ -87,17 +86,17 @@ def is_whitelisted_for_graph(o):
# Builtins typically have unnamed modules. # Builtins typically have unnamed modules.
for prefix, in config.DEFAULT_UNCOMPILED_MODULES: for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
if m.__name__.startswith(prefix): if m.__name__.startswith(prefix):
logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix) logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
return True return True
# Temporary -- whitelist tensorboard modules. # Temporary -- whitelist tensorboard modules.
# TODO(b/122731813): Remove. # TODO(b/122731813): Remove.
if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__: if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
logging.log(2, '%s is whitelisted: name contains "tensorboard"', o) logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
return True return True
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'): if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
logging.log(2, '%s is whitelisted: already converted', o) logging.log(2, 'Whitelisted: %s: already converted', o)
return True return True
if hasattr(o, '__call__'): if hasattr(o, '__call__'):
@ -105,9 +104,10 @@ def is_whitelisted_for_graph(o):
# The type check avoids infinite recursion around the __call__ method # The type check avoids infinite recursion around the __call__ method
# of function objects. # of function objects.
if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
logging.log(2, '%s is whitelisted: object __call__ whitelisted', o) logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
return True return True
owner_class = None
if tf_inspect.ismethod(o): if tf_inspect.ismethod(o):
# Methods of whitelisted classes are also whitelisted, even if they are # Methods of whitelisted classes are also whitelisted, even if they are
# bound via user subclasses. # bound via user subclasses.
@ -127,12 +127,12 @@ def is_whitelisted_for_graph(o):
owner_class = inspect_utils.getmethodclass(o) owner_class = inspect_utils.getmethodclass(o)
if owner_class is not None: if owner_class is not None:
if issubclass(owner_class, unittest.TestCase): if issubclass(owner_class, unittest.TestCase):
logging.log(2, '%s is whitelisted: method of TestCase subclass', o) logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
return True return True
owner_class = inspect_utils.getdefiningclass(o, owner_class) owner_class = inspect_utils.getdefiningclass(o, owner_class)
if is_whitelisted_for_graph(owner_class): if is_whitelisted_for_graph(owner_class):
logging.log(2, '%s is whitelisted: owner is whitelisted %s', o, logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
owner_class) owner_class)
return True return True
@ -145,10 +145,10 @@ def is_whitelisted_for_graph(o):
'Entity {} looks like a namedtuple subclass. Its constructor will' 'Entity {} looks like a namedtuple subclass. Its constructor will'
' not be converted by AutoGraph, but if it has any custom methods,' ' not be converted by AutoGraph, but if it has any custom methods,'
' those will be.'.format(o), 1) ' those will be.'.format(o), 1)
logging.log(2, '%s is whitelisted: named tuple', o) logging.log(2, 'Whitelisted: %s: named tuple', o)
return True return True
logging.log(2, '%s is NOT whitelisted', o) logging.log(2, 'Not whitelisted: %s: default rule', o)
return False return False