Whitelist functions that don't expose a __code__ object. This includes things like native bindings (such as NumPy), as well as harder-to-identify entities like method-wrapper objects. It also allows the converter to step into some of the TF op implementations without error.

PiperOrigin-RevId: 243121306
This commit is contained in:
Dan Moldovan 2019-04-11 12:39:27 -07:00 committed by TensorFlower Gardener
parent 48cb1ae640
commit fce41a2320
4 changed files with 78 additions and 21 deletions

View File

@ -53,8 +53,8 @@ DEFAULT_UNCOMPILED_MODULES = set((
# TODO(mdan): Remove once the conversion process is optimized.
('tensorflow_probability',),
(_internal_name('tensorflow_probability'),),
# TODO(b/130313089): Remove.
('numpy',),
# TODO(mdan): Might need to add "thread" as well?
('threading',),
))

View File

@ -208,15 +208,12 @@ def _is_known_loaded_type(f, module_name, entity_name):
def converted_call(f, owner, options, args, kwargs):
"""Compiles a function call inline. For internal use only."""
logging.log(1,
'Converted call: %s; owner: %s\n args: %s\n kwargs: %s\n',
f, owner, args, kwargs)
if owner is not None:
if not isinstance(f, str):
raise ValueError(
'When owner is specified, the function name must be specified as'
' a string: {}'.format(f))
owner_attr = f
# Special case when the owner is a 'super' object. In that case lookups of
# dynamic attributes won't work. See
@ -226,16 +223,22 @@ def converted_call(f, owner, options, args, kwargs):
f = getattr(owner, f)
if logging.has_verbosity(1):
if owner is not None:
composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
else:
composite_desc = ''
logging.log(1,
'Converted call: %s %s\n args: %s\n kwargs: %s\n',
f, composite_desc, args, kwargs)
if inspect_utils.isbuiltin(f):
if kwargs:
return py_builtins.overload_of(f)(*args, **kwargs)
else:
return py_builtins.overload_of(f)(*args)
if _is_known_loaded_type(f, 'weakref', 'ref'):
logging.log(2, 'Permanently whitelisted: %s: weakref', f)
return _call_unconverted(f, args, kwargs)
# TODO(b/122265385): Remove this bypass.
if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
_is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
@ -259,9 +262,7 @@ def converted_call(f, owner, options, args, kwargs):
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
# Note: TF linter disallows importing inspect.
if any(f in m.__dict__.values()
for m in (collections, pdb, copy, tf_inspect._inspect)): # pylint:disable=protected-access
if any(f in m.__dict__.values() for m in (collections, pdb, copy, inspect)):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return _call_unconverted(f, args, kwargs)
@ -318,6 +319,12 @@ def converted_call(f, owner, options, args, kwargs):
target_entity = f
raise NotImplementedError('unknown callable type "%s"' % type(f))
if (not tf_inspect.isclass(target_entity) and
not hasattr(target_entity, '__code__')):
logging.log(
2, 'Permanently whitelisted: %s: native binding', target_entity)
return _call_unconverted(f, args, kwargs)
converted_f = to_graph(
target_entity,
recursive=options.recursive,

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import collections
import functools
import gc
import os
import types
import numpy as np
@ -36,6 +38,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@ -267,6 +270,39 @@ class ApiTest(test.TestCase):
(), {})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_synthetic_method(self):
class TestClass(object):
def __init__(self, x):
self.x = x
def test_function(self):
if self.x < 0:
return -self.x
return self.x
tc = TestClass(constant_op.constant(-1))
test_method = types.MethodType(test_function, tc)
x = api.converted_call(test_method, None, converter.ConversionOptions(),
(), {})
self.assertEqual(1, self.evaluate(x))
def test_converted_call_method_wrapper(self):
class TestClass(object):
def foo(self):
pass
tc = TestClass()
# `method.__get__()` returns a so-called method-wrapper.
wrapper = api.converted_call(
'__get__', tc.foo, converter.ConversionOptions(), (tc,), {})
self.assertEqual(wrapper, tc.foo)
def test_converted_call_method_as_object_attribute(self):
class AnotherClass(object):
@ -453,6 +489,24 @@ class ApiTest(test.TestCase):
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
def test_converted_call_numpy(self):
opts = converter.ConversionOptions()
x = api.converted_call(np.arange, None, opts, (5,), {})
self.assertAllEqual(x, list(range(5)))
def test_converted_call_tf_op_forced(self):
# TODO(mdan): Add the missing level of support to LOGICAL_EXPRESSIONS.
opts = converter.ConversionOptions(
force_conversion=True, optional_features=None)
x = api.converted_call(gen_math_ops.add, None, opts, (1, 1), {})
self.assertAllEqual(self.evaluate(x), 2)
def test_converted_call_namedtuple(self):
opts = converter.ConversionOptions()
@ -684,4 +738,5 @@ class ApiTest(test.TestCase):
if __name__ == '__main__':
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
test.main()

View File

@ -435,17 +435,12 @@ def convert_entity_to_ast(o, program_ctx):
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
elif tf_inspect.ismethod(o):
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
# TODO(mdan,yashkatariya): Remove when object conversion is implemented.
elif hasattr(o, '__class__'):
# Note: this should only be raised when attempting to convert the object
# directly. converted_call should still support it.
raise NotImplementedError(
'Object conversion is not yet supported. If you are '
'trying to convert code that uses an existing object, '
'try including the creation of that object in the '
'conversion. For example, instead of converting the method '
'of a class, try converting the entire class instead. '
'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
'python/autograph/README.md#using-the-functional-api '
'for more information.')
'cannot convert entity "{}": object conversion is not yet'
' supported.'.format(o))
else:
raise ValueError(
'Entity "%s" has unsupported type "%s". Only functions and classes are '