Use eager_py_func instead of the deprecated py_func. Add a bit of extra logic to account for the potential use of tf.print in Python 2.
PiperOrigin-RevId: 224574979
This commit is contained in:
parent
a64a8d8d02
commit
71ea120a1b
tensorflow/python/autograph
@ -174,6 +174,7 @@ def _tf_py_func_print(objects, kwargs):
|
||||
override_kwargs['flush'] = True
|
||||
|
||||
def print_wrapper(*vals):
|
||||
vals = tuple(v.numpy() if tensor_util.is_tensor(v) else v for v in vals)
|
||||
if six.PY3:
|
||||
# TensorFlow doesn't seem to generate Unicode when passing strings to
|
||||
# py_func. This causes the print to add a "b'" wrapper to the output,
|
||||
@ -193,6 +194,7 @@ def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
|
||||
|
||||
|
||||
def _tf_range(start_or_stop, stop, step):
|
||||
"""Overload of range_ that generates a TF range tensor."""
|
||||
# Note: for static inputs (e.g. constants), tf.range errors out at graph
|
||||
# construction time, instead of returning an empty tensor. Preventing the
|
||||
# graph construction error aligns the semantics with Python.
|
||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
||||
import textwrap
|
||||
|
||||
import gast
|
||||
import six
|
||||
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
@ -91,7 +92,17 @@ def parse_entity(entity):
|
||||
def parse_str(src):
|
||||
"""Returns the AST of given piece of code."""
|
||||
# TODO(mdan): This should exclude the module things are autowrapped in.
|
||||
return gast.parse(src)
|
||||
|
||||
if six.PY2 and '.print(' in src:
|
||||
# This special treatment is required because gast.parse is not aware of
|
||||
# whether print_function was present in the original context.
|
||||
src = 'from __future__ import print_function\n' + src
|
||||
parsed_module = gast.parse(src)
|
||||
parsed_module.body = parsed_module.body[1:]
|
||||
else:
|
||||
parsed_module = gast.parse(src)
|
||||
|
||||
return parsed_module
|
||||
|
||||
|
||||
def parse_expression(src):
|
||||
|
@ -127,5 +127,6 @@ def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
|
||||
retval = f(*f_args, **f_kwargs)
|
||||
return 1 if use_dummy_return else retval
|
||||
|
||||
return script_ops.py_func(f_wrapper, tensor_args, dtypes.int64
|
||||
if use_dummy_return else return_dtypes)
|
||||
if use_dummy_return:
|
||||
return_dtypes = dtypes.int32
|
||||
return script_ops.eager_py_func(f_wrapper, tensor_args, return_dtypes)
|
||||
|
@ -32,13 +32,13 @@ class PyFuncTest(test.TestCase):
|
||||
return a + b + c
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32,
|
||||
(1, constant_op.constant(1), 1))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32, (1, 1, 1))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(
|
||||
test_fn, dtypes.int64,
|
||||
test_fn, dtypes.int32,
|
||||
(constant_op.constant(1), 1, constant_op.constant(1)))
|
||||
self.assertEqual(3, self.evaluate(result))
|
||||
|
||||
@ -53,9 +53,9 @@ class PyFuncTest(test.TestCase):
|
||||
return a * b.foo
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32, (7, TestClass()))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32,
|
||||
(constant_op.constant(7), TestClass()))
|
||||
self.assertEqual(35, self.evaluate(result))
|
||||
|
||||
@ -70,12 +70,12 @@ class PyFuncTest(test.TestCase):
|
||||
return a * b.foo + c * d.foo
|
||||
|
||||
with self.cached_session() as sess:
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32, (7, TestClass(5)), {
|
||||
'c': 11,
|
||||
'd': TestClass(13)
|
||||
})
|
||||
self.assertEqual(178, self.evaluate(result))
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int64,
|
||||
result = py_func.wrap_py_func(test_fn, dtypes.int32,
|
||||
(constant_op.constant(7), TestClass(5)), {
|
||||
'c': constant_op.constant(11),
|
||||
'd': TestClass(13)
|
||||
|
Loading…
Reference in New Issue
Block a user