Whitelist functions that don't expose a __code__ object. This includes things like native bindings (such as NumPy), as well as harder-to-identify entities like method-wrapper objects. It also allows the converter to step into some of the TF op implementations without error.
PiperOrigin-RevId: 243121306
This commit is contained in:
parent
48cb1ae640
commit
fce41a2320
@ -53,8 +53,8 @@ DEFAULT_UNCOMPILED_MODULES = set((
|
||||
# TODO(mdan): Remove once the conversion process is optimized.
|
||||
('tensorflow_probability',),
|
||||
(_internal_name('tensorflow_probability'),),
|
||||
|
||||
# TODO(b/130313089): Remove.
|
||||
('numpy',),
|
||||
# TODO(mdan): Might need to add "thread" as well?
|
||||
('threading',),
|
||||
))
|
||||
|
@ -208,15 +208,12 @@ def _is_known_loaded_type(f, module_name, entity_name):
|
||||
|
||||
def converted_call(f, owner, options, args, kwargs):
|
||||
"""Compiles a function call inline. For internal use only."""
|
||||
logging.log(1,
|
||||
'Converted call: %s; owner: %s\n args: %s\n kwargs: %s\n',
|
||||
f, owner, args, kwargs)
|
||||
|
||||
if owner is not None:
|
||||
if not isinstance(f, str):
|
||||
raise ValueError(
|
||||
'When owner is specified, the function name must be specified as'
|
||||
' a string: {}'.format(f))
|
||||
owner_attr = f
|
||||
|
||||
# Special case when the owner is a 'super' object. In that case lookups of
|
||||
# dynamic attributes won't work. See
|
||||
@ -226,16 +223,22 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
|
||||
f = getattr(owner, f)
|
||||
|
||||
if logging.has_verbosity(1):
|
||||
if owner is not None:
|
||||
composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
|
||||
else:
|
||||
composite_desc = ''
|
||||
|
||||
logging.log(1,
|
||||
'Converted call: %s %s\n args: %s\n kwargs: %s\n',
|
||||
f, composite_desc, args, kwargs)
|
||||
|
||||
if inspect_utils.isbuiltin(f):
|
||||
if kwargs:
|
||||
return py_builtins.overload_of(f)(*args, **kwargs)
|
||||
else:
|
||||
return py_builtins.overload_of(f)(*args)
|
||||
|
||||
if _is_known_loaded_type(f, 'weakref', 'ref'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: weakref', f)
|
||||
return _call_unconverted(f, args, kwargs)
|
||||
|
||||
# TODO(b/122265385): Remove this bypass.
|
||||
if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
|
||||
_is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
|
||||
@ -259,9 +262,7 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
|
||||
# Other built-in modules are permanently whitelisted.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
# Note: TF linter disallows importing inspect.
|
||||
if any(f in m.__dict__.values()
|
||||
for m in (collections, pdb, copy, tf_inspect._inspect)): # pylint:disable=protected-access
|
||||
if any(f in m.__dict__.values() for m in (collections, pdb, copy, inspect)):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||
return _call_unconverted(f, args, kwargs)
|
||||
|
||||
@ -318,6 +319,12 @@ def converted_call(f, owner, options, args, kwargs):
|
||||
target_entity = f
|
||||
raise NotImplementedError('unknown callable type "%s"' % type(f))
|
||||
|
||||
if (not tf_inspect.isclass(target_entity) and
|
||||
not hasattr(target_entity, '__code__')):
|
||||
logging.log(
|
||||
2, 'Permanently whitelisted: %s: native binding', target_entity)
|
||||
return _call_unconverted(f, args, kwargs)
|
||||
|
||||
converted_f = to_graph(
|
||||
target_entity,
|
||||
recursive=options.recursive,
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import collections
|
||||
import functools
|
||||
import gc
|
||||
import os
|
||||
import types
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -36,6 +38,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import tf_inspect
|
||||
@ -267,6 +270,39 @@ class ApiTest(test.TestCase):
|
||||
(), {})
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_synthetic_method(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
def test_function(self):
|
||||
if self.x < 0:
|
||||
return -self.x
|
||||
return self.x
|
||||
|
||||
tc = TestClass(constant_op.constant(-1))
|
||||
test_method = types.MethodType(test_function, tc)
|
||||
|
||||
x = api.converted_call(test_method, None, converter.ConversionOptions(),
|
||||
(), {})
|
||||
self.assertEqual(1, self.evaluate(x))
|
||||
|
||||
def test_converted_call_method_wrapper(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
tc = TestClass()
|
||||
|
||||
# `method.__get__()` returns a so-called method-wrapper.
|
||||
wrapper = api.converted_call(
|
||||
'__get__', tc.foo, converter.ConversionOptions(), (tc,), {})
|
||||
self.assertEqual(wrapper, tc.foo)
|
||||
|
||||
def test_converted_call_method_as_object_attribute(self):
|
||||
|
||||
class AnotherClass(object):
|
||||
@ -453,6 +489,24 @@ class ApiTest(test.TestCase):
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
|
||||
|
||||
def test_converted_call_numpy(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
|
||||
x = api.converted_call(np.arange, None, opts, (5,), {})
|
||||
|
||||
self.assertAllEqual(x, list(range(5)))
|
||||
|
||||
def test_converted_call_tf_op_forced(self):
|
||||
|
||||
# TODO(mdan): Add the missing level of support to LOGICAL_EXPRESSIONS.
|
||||
opts = converter.ConversionOptions(
|
||||
force_conversion=True, optional_features=None)
|
||||
|
||||
x = api.converted_call(gen_math_ops.add, None, opts, (1, 1), {})
|
||||
|
||||
self.assertAllEqual(self.evaluate(x), 2)
|
||||
|
||||
def test_converted_call_namedtuple(self):
|
||||
|
||||
opts = converter.ConversionOptions()
|
||||
@ -684,4 +738,5 @@ class ApiTest(test.TestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||
test.main()
|
||||
|
@ -435,17 +435,12 @@ def convert_entity_to_ast(o, program_ctx):
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
elif tf_inspect.ismethod(o):
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
# TODO(mdan,yashkatariya): Remove when object conversion is implemented.
|
||||
elif hasattr(o, '__class__'):
|
||||
# Note: this should only be raised when attempting to convert the object
|
||||
# directly. converted_call should still support it.
|
||||
raise NotImplementedError(
|
||||
'Object conversion is not yet supported. If you are '
|
||||
'trying to convert code that uses an existing object, '
|
||||
'try including the creation of that object in the '
|
||||
'conversion. For example, instead of converting the method '
|
||||
'of a class, try converting the entire class instead. '
|
||||
'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
|
||||
'python/autograph/README.md#using-the-functional-api '
|
||||
'for more information.')
|
||||
'cannot convert entity "{}": object conversion is not yet'
|
||||
' supported.'.format(o))
|
||||
else:
|
||||
raise ValueError(
|
||||
'Entity "%s" has unsupported type "%s". Only functions and classes are '
|
||||
|
Loading…
Reference in New Issue
Block a user