Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.
In particular, an EagerPyFunc op is added that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs. Because functions wrapped in an EagerPyFunc are executed eagerly, they can make use of TensorFlow operations. EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops. tfe.py_func will also make it easier to translate python functions with side-effects into defun-able code. PiperOrigin-RevId: 178303818
This commit is contained in:
parent
2d4c29cd6a
commit
f37380b064
@ -19,6 +19,7 @@ py_library(
|
|||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:numerics",
|
"//tensorflow/python:numerics",
|
||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
|
"//tensorflow/python:script_ops",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python/eager:backprop",
|
"//tensorflow/python/eager:backprop",
|
||||||
|
@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
|||||||
@@list_devices
|
@@list_devices
|
||||||
@@num_gpus
|
@@num_gpus
|
||||||
|
|
||||||
|
@@py_func
|
||||||
@@defun
|
@@defun
|
||||||
@@implicit_gradients
|
@@implicit_gradients
|
||||||
@@implicit_value_and_gradients
|
@@implicit_value_and_gradients
|
||||||
@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest
|
|||||||
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
|
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
|
||||||
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
|
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
|
||||||
from tensorflow.python.ops.variable_scope import EagerVariableStore
|
from tensorflow.python.ops.variable_scope import EagerVariableStore
|
||||||
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
py_func = script_ops.eager_py_func
|
||||||
defun = function.defun
|
defun = function.defun
|
||||||
implicit_gradients = backprop.implicit_grad
|
implicit_gradients = backprop.implicit_grad
|
||||||
implicit_value_and_gradients = backprop.implicit_val_and_grad
|
implicit_value_and_gradients = backprop.implicit_val_and_grad
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "EagerPyFunc"
|
||||||
|
summary: "Eagerly executes a python function to compute func(input)->output. The"
|
||||||
|
description: <<END
|
||||||
|
semantics of the input, output, and attributes are the same as those for
|
||||||
|
PyFunc.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "EagerPyFunc"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -51,4 +51,18 @@ REGISTER_OP("PyFuncStateless")
|
|||||||
A stateless version of PyFunc.
|
A stateless version of PyFunc.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("EagerPyFunc")
|
||||||
|
.Input("input: Tin")
|
||||||
|
.Output("output: Tout")
|
||||||
|
.Attr("token: string")
|
||||||
|
.Attr("Tin: list(type) >= 0")
|
||||||
|
.Attr("Tout: list(type) >=0")
|
||||||
|
.SetIsStateful()
|
||||||
|
.SetShapeFn(shape_inference::UnknownShape)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Eagerly executes a python function to compute func(input)->output. The
|
||||||
|
semantics of the input, output, and attributes are the same as those for
|
||||||
|
PyFunc.
|
||||||
|
)doc");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -280,10 +280,14 @@ cc_library(
|
|||||||
":ndarray_tensor_bridge",
|
":ndarray_tensor_bridge",
|
||||||
":numpy_lib",
|
":numpy_lib",
|
||||||
":py_util",
|
":py_util",
|
||||||
|
":safe_ptr",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/c/eager:c_api",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:script_ops_op_lib",
|
"//tensorflow/core:script_ops_op_lib",
|
||||||
|
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||||
"//third_party/py/numpy:headers",
|
"//third_party/py/numpy:headers",
|
||||||
"//util/python:python_headers",
|
"//util/python:python_headers",
|
||||||
],
|
],
|
||||||
|
@ -1645,6 +1645,8 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:script_ops",
|
"//tensorflow/python:script_ops",
|
||||||
|
"//tensorflow/python/eager:context",
|
||||||
|
"//tensorflow/python/eager:function",
|
||||||
],
|
],
|
||||||
tags = ["no_windows"],
|
tags = ["no_windows"],
|
||||||
)
|
)
|
||||||
|
@ -23,82 +23,93 @@ from six.moves import queue
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class PyOpTest(test.TestCase):
|
def np_func(x, y):
|
||||||
|
return np.sinh(x) + np.cosh(y)
|
||||||
|
|
||||||
def testBasic(self):
|
|
||||||
|
|
||||||
def my_func(x, y):
|
def matmul(x, y):
|
||||||
return np.sinh(x) + np.cosh(y)
|
return math_ops.matmul(x, y)
|
||||||
|
|
||||||
# single type
|
|
||||||
|
class PyFuncTest(test.TestCase):
|
||||||
|
"""Encapsulates tests for py_func and eager_py_func."""
|
||||||
|
|
||||||
|
# ----- Tests for py_func -----
|
||||||
|
def testSingleType(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant(1.0, dtypes.float32)
|
x = constant_op.constant(1.0, dtypes.float32)
|
||||||
y = constant_op.constant(2.0, dtypes.float32)
|
y = constant_op.constant(2.0, dtypes.float32)
|
||||||
z = script_ops.py_func(my_func, [x, y], dtypes.float32)
|
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
|
||||||
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
|
self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
|
||||||
|
|
||||||
# scalar
|
def testScalar(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant(1.0, dtypes.float32)
|
x = constant_op.constant(1.0, dtypes.float32)
|
||||||
y = constant_op.constant(2.0, dtypes.float32)
|
y = constant_op.constant(2.0, dtypes.float32)
|
||||||
z = script_ops.py_func(my_func, [x, y], [dtypes.float32])
|
z = self.evaluate(
|
||||||
self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
|
script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
|
||||||
|
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
|
||||||
|
|
||||||
# array
|
def testArray(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant([1.0, 2.0], dtypes.float64)
|
x = constant_op.constant([1.0, 2.0], dtypes.float64)
|
||||||
y = constant_op.constant([2.0, 3.0], dtypes.float64)
|
y = constant_op.constant([2.0, 3.0], dtypes.float64)
|
||||||
z = script_ops.py_func(my_func, [x, y], [dtypes.float64])
|
z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
|
||||||
self.assertAllEqual(z[0].eval(),
|
self.assertAllEqual(z[0],
|
||||||
my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
|
np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
|
||||||
|
|
||||||
# a bit exotic type (complex64)
|
def testComplexType(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant(1 + 2j, dtypes.complex64)
|
x = constant_op.constant(1 + 2j, dtypes.complex64)
|
||||||
y = constant_op.constant(3 + 4j, dtypes.complex64)
|
y = constant_op.constant(3 + 4j, dtypes.complex64)
|
||||||
z, = script_ops.py_func(my_func, [x, y], [dtypes.complex64])
|
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
|
||||||
self.assertAllClose(z.eval(), my_func(1 + 2j, 3 + 4j))
|
self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
|
||||||
|
|
||||||
# a bit excotic function (rfft)
|
def testRFFT(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
|
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
|
||||||
|
|
||||||
def rfft(x):
|
def rfft(x):
|
||||||
return np.fft.rfft(x).astype(np.complex64)
|
return np.fft.rfft(x).astype(np.complex64)
|
||||||
|
|
||||||
y, = script_ops.py_func(rfft, [x], [dtypes.complex64])
|
y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
|
||||||
self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
|
self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
|
||||||
|
|
||||||
# returns a python literal.
|
def testPythonLiteral(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
|
||||||
def literal(x):
|
def literal(x):
|
||||||
return 1.0 if x == 0.0 else 0.0
|
return 1.0 if float(x) == 0.0 else 0.0
|
||||||
|
|
||||||
x = constant_op.constant(0.0, dtypes.float64)
|
x = constant_op.constant(0.0, dtypes.float64)
|
||||||
y, = script_ops.py_func(literal, [x], [dtypes.float64])
|
y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
|
||||||
self.assertAllClose(y.eval(), 1.0)
|
self.assertAllClose(y, 1.0)
|
||||||
|
|
||||||
# returns a list
|
def testList(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
|
||||||
def list_func(x):
|
def list_func(x):
|
||||||
return [x, x + 1]
|
return [x, x + 1]
|
||||||
|
|
||||||
x = constant_op.constant(0.0, dtypes.float64)
|
x = constant_op.constant(0.0, dtypes.float64)
|
||||||
y, z = script_ops.py_func(list_func, [x], [dtypes.float64] * 2)
|
y = self.evaluate(
|
||||||
self.assertAllClose(y.eval(), 0.0)
|
script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
|
||||||
self.assertAllClose(z.eval(), 1.0)
|
self.assertAllClose(y, [0.0, 1.0])
|
||||||
|
|
||||||
|
def testTuple(self):
|
||||||
# returns a tuple
|
# returns a tuple
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
|
|
||||||
@ -106,17 +117,17 @@ class PyOpTest(test.TestCase):
|
|||||||
return x, x + 1
|
return x, x + 1
|
||||||
|
|
||||||
x = constant_op.constant(0.0, dtypes.float64)
|
x = constant_op.constant(0.0, dtypes.float64)
|
||||||
y, z = script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2)
|
y = self.evaluate(
|
||||||
self.assertAllClose(y.eval(), 0.0)
|
script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
|
||||||
self.assertAllClose(z.eval(), 1.0)
|
self.assertAllClose(y, [0.0, 1.0])
|
||||||
|
|
||||||
# returns a tuple, Tout and inp a tuple
|
# returns a tuple, Tout and inp a tuple
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant(0.0, dtypes.float64)
|
x = constant_op.constant(0.0, dtypes.float64)
|
||||||
y, z = script_ops.py_func(tuple_func, (x,), (dtypes.float64,
|
y = self.evaluate(
|
||||||
dtypes.float64))
|
script_ops.py_func(tuple_func, (x,),
|
||||||
self.assertAllClose(y.eval(), 0.0)
|
(dtypes.float64, dtypes.float64)))
|
||||||
self.assertAllClose(z.eval(), 1.0)
|
self.assertAllClose(y, [0.0, 1.0])
|
||||||
|
|
||||||
def testStrings(self):
|
def testStrings(self):
|
||||||
|
|
||||||
@ -128,10 +139,12 @@ class PyOpTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
|
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
|
||||||
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
|
y = self.evaluate(
|
||||||
[dtypes.string])
|
script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||||
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
|
dtypes.string))
|
||||||
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
|
z = self.evaluate(
|
||||||
|
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
|
||||||
|
self.assertAllEqual(z, [b"hello there", b"hi there"])
|
||||||
|
|
||||||
def testStringsAreConvertedToBytes(self):
|
def testStringsAreConvertedToBytes(self):
|
||||||
|
|
||||||
@ -143,10 +156,12 @@ class PyOpTest(test.TestCase):
|
|||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x = constant_op.constant(["hello", "hi"], dtypes.string)
|
x = constant_op.constant(["hello", "hi"], dtypes.string)
|
||||||
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
|
y = self.evaluate(
|
||||||
[dtypes.string])
|
script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||||
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
|
dtypes.string))
|
||||||
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
|
z = self.evaluate(
|
||||||
|
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
|
||||||
|
self.assertAllEqual(z, [b"hello there", b"hi there"])
|
||||||
|
|
||||||
def testObjectArraysAreConvertedToBytes(self):
|
def testObjectArraysAreConvertedToBytes(self):
|
||||||
|
|
||||||
@ -186,16 +201,8 @@ class PyOpTest(test.TestCase):
|
|||||||
|
|
||||||
def testNoInput(self):
|
def testNoInput(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
x, = script_ops.py_func(lambda: 42.0, [], [dtypes.float64])
|
x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
|
||||||
self.assertAllClose(x.eval(), 42.0)
|
self.assertAllClose(x, 42.0)
|
||||||
|
|
||||||
def testCleanup(self):
|
|
||||||
for _ in xrange(1000):
|
|
||||||
g = ops.Graph()
|
|
||||||
with g.as_default():
|
|
||||||
c = constant_op.constant([1.], dtypes.float32)
|
|
||||||
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
|
|
||||||
self.assertTrue(script_ops._py_funcs.size() < 100)
|
|
||||||
|
|
||||||
def testAlias(self):
|
def testAlias(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
@ -242,8 +249,8 @@ class PyOpTest(test.TestCase):
|
|||||||
# Create a numpy array aliasing a tensor and a tensor aliasing this array
|
# Create a numpy array aliasing a tensor and a tensor aliasing this array
|
||||||
z, = script_ops.py_func(ident, [p], [dtypes.float32])
|
z, = script_ops.py_func(ident, [p], [dtypes.float32])
|
||||||
z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0]
|
z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0]
|
||||||
# above instead of using its memory as the return value of
|
# above instead of using its memory as the return value of
|
||||||
# session.run
|
# session.run
|
||||||
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
|
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
|
||||||
|
|
||||||
def testStateful(self):
|
def testStateful(self):
|
||||||
@ -319,10 +326,10 @@ class PyOpTest(test.TestCase):
|
|||||||
def value(self):
|
def value(self):
|
||||||
return self._value
|
return self._value
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session():
|
||||||
s = State()
|
s = State()
|
||||||
op = s.increment(constant_op.constant(2, dtypes.int64))
|
op = s.increment(constant_op.constant(2, dtypes.int64))
|
||||||
ret = sess.run(op)
|
ret = self.evaluate(op)
|
||||||
self.assertIsNone(ret)
|
self.assertIsNone(ret)
|
||||||
self.assertAllEqual([3], s.value)
|
self.assertAllEqual([3], s.value)
|
||||||
|
|
||||||
@ -336,15 +343,24 @@ class PyOpTest(test.TestCase):
|
|||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
self.assertEqual(sess.run(f), [])
|
self.assertEqual(sess.run(f), [])
|
||||||
|
|
||||||
def _testExceptionHandling(self, py_exp, tf_exp):
|
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
|
||||||
|
|
||||||
def raise_exception():
|
def raise_exception():
|
||||||
raise py_exp("blah") # pylint: disable=not-callable
|
raise py_exp("blah") # pylint: disable=not-callable
|
||||||
|
|
||||||
f = script_ops.py_func(raise_exception, [], [])
|
if eager:
|
||||||
with self.test_session() as sess:
|
if context.in_eager_mode():
|
||||||
|
with self.assertRaisesRegexp(tf_exp, "blah"):
|
||||||
|
f = script_ops.eager_py_func(raise_exception, [], [])
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
f = script_ops.eager_py_func(raise_exception, [], [])
|
||||||
|
else:
|
||||||
|
f = script_ops.py_func(raise_exception, [], [])
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
with self.assertRaisesRegexp(tf_exp, "blah"):
|
with self.assertRaisesRegexp(tf_exp, "blah"):
|
||||||
sess.run(f)
|
self.evaluate(f)
|
||||||
|
|
||||||
def testExceptionHandling(self):
|
def testExceptionHandling(self):
|
||||||
self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
|
self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
|
||||||
@ -358,6 +374,89 @@ class PyOpTest(test.TestCase):
|
|||||||
|
|
||||||
self._testExceptionHandling(WeirdError, errors.UnknownError)
|
self._testExceptionHandling(WeirdError, errors.UnknownError)
|
||||||
|
|
||||||
|
# ----- Tests shared by py_func and eager_py_func -----
|
||||||
|
def testCleanup(self):
|
||||||
|
for _ in xrange(1000):
|
||||||
|
g = ops.Graph()
|
||||||
|
with g.as_default():
|
||||||
|
c = constant_op.constant([1.], dtypes.float32)
|
||||||
|
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
|
||||||
|
_ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
|
||||||
|
self.assertTrue(script_ops._py_funcs.size() < 100)
|
||||||
|
|
||||||
|
# ----- Tests for eager_py_func -----
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerSingleOutputInt32(self):
|
||||||
|
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||||
|
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||||
|
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
|
||||||
|
with self.test_session():
|
||||||
|
ret = self.evaluate(output)
|
||||||
|
self.assertAllEqual(ret, [[3], [3], [3]])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerSingleOutputFloat32(self):
|
||||||
|
a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||||
|
x = array_ops.ones((3, 1), dtype=dtypes.float32)
|
||||||
|
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
|
||||||
|
with self.test_session():
|
||||||
|
ret = self.evaluate(output)
|
||||||
|
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerArrayOutput(self):
|
||||||
|
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||||
|
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||||
|
output = script_ops.eager_py_func(
|
||||||
|
lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
|
||||||
|
|
||||||
|
with self.test_session():
|
||||||
|
ret = self.evaluate(output)
|
||||||
|
self.assertAllEqual(ret, [[[3], [3], [3]]])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerReturnNone(self):
|
||||||
|
|
||||||
|
def no_return_value():
|
||||||
|
return
|
||||||
|
|
||||||
|
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
|
||||||
|
ret = self.evaluate(output)
|
||||||
|
if context.in_eager_mode():
|
||||||
|
self.assertEquals(len(ret), 0)
|
||||||
|
else:
|
||||||
|
self.assertIsNone(ret)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerPyFuncInDefun(self):
|
||||||
|
|
||||||
|
def wrapper():
|
||||||
|
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||||
|
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||||
|
return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
|
||||||
|
|
||||||
|
wrapped = function.defun(wrapper)
|
||||||
|
ret = self.evaluate(wrapped())
|
||||||
|
self.assertAllEqual(ret, [[3], [3], [3]])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def testEagerExceptionHandling(self):
|
||||||
|
self._testExceptionHandling(
|
||||||
|
ValueError, errors.InvalidArgumentError, eager=True)
|
||||||
|
self._testExceptionHandling(
|
||||||
|
TypeError, errors.InvalidArgumentError, eager=True)
|
||||||
|
self._testExceptionHandling(
|
||||||
|
StopIteration, errors.OutOfRangeError, eager=True)
|
||||||
|
self._testExceptionHandling(
|
||||||
|
MemoryError, errors.ResourceExhaustedError, eager=True)
|
||||||
|
self._testExceptionHandling(
|
||||||
|
NotImplementedError, errors.UnimplementedError, eager=True)
|
||||||
|
|
||||||
|
class WeirdError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
|||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
#include "numpy/arrayobject.h"
|
#include "numpy/arrayobject.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -25,8 +27,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||||
#include "tensorflow/python/lib/core/py_util.h"
|
#include "tensorflow/python/lib/core/py_util.h"
|
||||||
|
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -48,6 +52,9 @@ struct PyCall {
|
|||||||
// with this "token".
|
// with this "token".
|
||||||
string token;
|
string token;
|
||||||
|
|
||||||
|
// True if the call is associated with an EagerPyFunc.
|
||||||
|
bool eager;
|
||||||
|
|
||||||
// Inputs and outputs of this function invocation.
|
// Inputs and outputs of this function invocation.
|
||||||
std::vector<Tensor> ins;
|
std::vector<Tensor> ins;
|
||||||
std::vector<Tensor> out;
|
std::vector<Tensor> out;
|
||||||
@ -55,19 +62,26 @@ struct PyCall {
|
|||||||
|
|
||||||
// Givens the 'call', prepares the token and inputs as a python tuple
|
// Givens the 'call', prepares the token and inputs as a python tuple
|
||||||
// that is appropriate for calling the trampoline.
|
// that is appropriate for calling the trampoline.
|
||||||
Status MakeArgTuple(PyCall* call, PyObject** tuple) {
|
Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
|
||||||
int64 n = call->ins.size();
|
int64 n = call->ins.size();
|
||||||
PyObject* lst = PyList_New(n);
|
PyObject* lst = PyList_New(n);
|
||||||
CHECK(lst);
|
CHECK(lst);
|
||||||
for (int64 i = 0; i < n; ++i) {
|
for (int64 i = 0; i < n; ++i) {
|
||||||
|
PyObject* arg = nullptr;
|
||||||
const Tensor& t = call->ins[i];
|
const Tensor& t = call->ins[i];
|
||||||
PyObject* a = nullptr;
|
if (call->eager) {
|
||||||
Status s = ConvertTensorToNdarray(t, &a);
|
arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
|
||||||
if (!s.ok()) {
|
if (arg == nullptr) {
|
||||||
Py_DECREF(lst);
|
return errors::Internal("Unable to procure EagerTensor from Tensor.");
|
||||||
return s;
|
}
|
||||||
|
} else {
|
||||||
|
Status s = ConvertTensorToNdarray(t, &arg);
|
||||||
|
if (!s.ok()) {
|
||||||
|
Py_DECREF(lst);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
PyList_SetItem(lst, i, a);
|
PyList_SetItem(lst, i, arg);
|
||||||
}
|
}
|
||||||
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
|
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
|
||||||
CHECK(*tuple);
|
CHECK(*tuple);
|
||||||
@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
|
|||||||
return item == Py_None;
|
return item == Py_None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
|
||||||
|
Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
|
||||||
|
Tensor* output_tensor,
|
||||||
|
TF_Status* tf_status) {
|
||||||
|
// TODO(akshayka): Lift the restriction requiring output tensors to
|
||||||
|
// lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
|
||||||
|
// tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
|
||||||
|
*output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||||
|
EagerTensor_Handle(eager_tensor), tf_status);
|
||||||
|
return StatusFromTF_Status(tf_status);
|
||||||
|
}
|
||||||
|
|
||||||
// Calls the registered py function through the trampoline.
|
// Calls the registered py function through the trampoline.
|
||||||
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||||
*out_log_on_error = true;
|
*out_log_on_error = true;
|
||||||
@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the return values and converts them to tf Tensors.
|
// Process the return values and convert them to TF Tensors.
|
||||||
Status s;
|
Status s;
|
||||||
if (PyList_Check(result)) {
|
if (PyList_Check(result)) {
|
||||||
// 'result' is a list.
|
|
||||||
call->out.clear();
|
call->out.clear();
|
||||||
for (int i = 0; i < PyList_Size(result); ++i) {
|
for (int i = 0; i < PyList_Size(result); ++i) {
|
||||||
Tensor t;
|
Tensor t;
|
||||||
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
|
if (call->eager) {
|
||||||
|
auto tf_status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
|
||||||
|
tf_status.get());
|
||||||
|
} else {
|
||||||
|
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
|
||||||
|
}
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
call->out.push_back(t);
|
call->out.push_back(t);
|
||||||
}
|
}
|
||||||
|
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
|
||||||
|
DCHECK(call->eager);
|
||||||
|
Tensor t;
|
||||||
|
if (result != Py_None) {
|
||||||
|
auto tf_status = tensorflow::make_safe(TF_NewStatus());
|
||||||
|
s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
|
||||||
|
if (s.ok()) {
|
||||||
|
call->out.push_back(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (PyArray_Check(result)) {
|
} else if (PyArray_Check(result)) {
|
||||||
// 'result' is a single ndarray.
|
DCHECK(!call->eager);
|
||||||
if (!IsSingleNone(result)) {
|
if (!IsSingleNone(result)) {
|
||||||
Tensor t;
|
Tensor t;
|
||||||
s = ConvertNdarrayToTensor(result, &t);
|
s = ConvertNdarrayToTensor(result, &t);
|
||||||
@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
|
||||||
|
eager_ = type_string() == "EagerPyFunc";
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
PyCall call;
|
PyCall call;
|
||||||
call.token = token_;
|
call.token = token_;
|
||||||
|
call.eager = eager_;
|
||||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||||
call.ins.push_back(ctx->input(i));
|
call.ins.push_back(ctx->input(i));
|
||||||
}
|
}
|
||||||
@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
|
|||||||
private:
|
private:
|
||||||
string token_;
|
string token_;
|
||||||
|
|
||||||
|
// True if and only if this op should execute the python function eagerly,
|
||||||
|
// i.e., if and only if the eager attribute is set.
|
||||||
|
bool eager_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
|
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
|
||||||
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
|
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -24,21 +24,27 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Called by py code on initialization.
|
// Called by python code on initialization.
|
||||||
//
|
//
|
||||||
// "trampoline" must represent a python function which has the
|
// "trampoline" must represent a python function which has the
|
||||||
// following signature:
|
// following signature:
|
||||||
// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
|
// (string, list(ndarray)) | (string, list(EagerTensor)) ->
|
||||||
|
// ndarray | list(ndarray) | python scalar |
|
||||||
|
// EagerTensor | list(EagerTensor) | None
|
||||||
//
|
//
|
||||||
// The trampoline takes two arguments, the first is a string token
|
// The trampoline takes two arguments, the first is a string token
|
||||||
// used by the python frontend's dispatching logic; the second is a
|
// used by the python frontend's dispatching logic; the second is a
|
||||||
// list of numpy ndarrays.
|
// list of numpy ndarrays or EagerTensor objects. It can return a
|
||||||
|
// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
|
||||||
|
// EagerTensor, a list of EagerTensors, or None.
|
||||||
//
|
//
|
||||||
// The trampoline can return a single numpy ndarray, a list of numpy
|
// PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
|
||||||
// ndarrays, or a simply python scalar. The C++ runtime converts them,
|
// inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
|
||||||
// if supported, back to Tensor objects.
|
// of EagerTensors, or None.
|
||||||
//
|
//
|
||||||
// This is called by script_ops.py during its module initialization.
|
// The C++ runtime converts outputs back to Tensor objects.
|
||||||
|
//
|
||||||
|
// This function is called by script_ops.py during its module initialization.
|
||||||
//
|
//
|
||||||
// TODO(zhifengc): Support distributed runtime.
|
// TODO(zhifengc): Support distributed runtime.
|
||||||
void InitializePyTrampoline(PyObject* trampoline);
|
void InitializePyTrampoline(PyObject* trampoline);
|
||||||
|
@ -341,6 +341,7 @@ TruncatedNormal
|
|||||||
# script_ops
|
# script_ops
|
||||||
PyFunc
|
PyFunc
|
||||||
PyFuncStateless
|
PyFuncStateless
|
||||||
|
EagerPyFunc
|
||||||
|
|
||||||
# sdca_ops
|
# sdca_ops
|
||||||
|
|
||||||
|
@ -29,11 +29,41 @@ import numpy as np
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gen_script_ops
|
from tensorflow.python.ops import gen_script_ops
|
||||||
|
|
||||||
|
|
||||||
|
class EagerFunc(object):
|
||||||
|
"""A wrapper for a function owned by an EagerPyFunc."""
|
||||||
|
|
||||||
|
def __init__(self, func, Tout):
|
||||||
|
"""Constructs an EagerFunc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: The function to wrap.
|
||||||
|
Tout: A list of datatypes for the output; an empty list if the output is
|
||||||
|
None.
|
||||||
|
"""
|
||||||
|
self._func = func
|
||||||
|
self._out_dtypes = Tout
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
"""Passes args, kwargs to `self._func`, which is executed eagerly."""
|
||||||
|
with context.eager_mode():
|
||||||
|
ret = self._func(*args, **kwargs)
|
||||||
|
if isinstance(ret, (tuple, list)):
|
||||||
|
return [
|
||||||
|
ops.convert_to_tensor(x, dtype=dtype)
|
||||||
|
for (x, dtype) in zip(ret, self._out_dtypes)
|
||||||
|
]
|
||||||
|
elif ret is None:
|
||||||
|
return ret
|
||||||
|
else:
|
||||||
|
return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
|
||||||
|
|
||||||
|
|
||||||
class FuncRegistry(object):
|
class FuncRegistry(object):
|
||||||
"""A helper class to keep track of registered py functions.
|
"""A helper class to keep track of registered py functions.
|
||||||
|
|
||||||
@ -91,16 +121,20 @@ class FuncRegistry(object):
|
|||||||
if func is None:
|
if func is None:
|
||||||
raise ValueError("callback %s is not found" % token)
|
raise ValueError("callback %s is not found" % token)
|
||||||
ret = func(*args)
|
ret = func(*args)
|
||||||
# Strings seem to lead to a memory leak here if they're not wrapped in a
|
|
||||||
# list.
|
if isinstance(func, EagerFunc):
|
||||||
if isinstance(ret, six.binary_type):
|
return ret
|
||||||
ret = [ret]
|
|
||||||
# Ensures that we return either a single numpy array or a list of numpy
|
|
||||||
# arrays.
|
|
||||||
if isinstance(ret, (tuple, list)):
|
|
||||||
return [self._convert(x) for x in ret]
|
|
||||||
else:
|
else:
|
||||||
return self._convert(ret)
|
# Strings seem to lead to a memory leak here if they're not wrapped in a
|
||||||
|
# list.
|
||||||
|
if isinstance(ret, six.binary_type):
|
||||||
|
ret = [ret]
|
||||||
|
# Ensures that we return either a single numpy array or a list of numpy
|
||||||
|
# arrays.
|
||||||
|
if isinstance(ret, (tuple, list)):
|
||||||
|
return [self._convert(x) for x in ret]
|
||||||
|
else:
|
||||||
|
return self._convert(ret)
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
"""Returns how many functions are currently registered."""
|
"""Returns how many functions are currently registered."""
|
||||||
@ -129,6 +163,86 @@ class CleanupFunc(object):
|
|||||||
_py_funcs.remove(self._token)
|
_py_funcs.remove(self._token)
|
||||||
|
|
||||||
|
|
||||||
|
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
|
||||||
|
"""See documentation for py_func and eager_py_func."""
|
||||||
|
|
||||||
|
is_list_or_tuple = False
|
||||||
|
if isinstance(Tout, (list, tuple)):
|
||||||
|
is_list_or_tuple = True
|
||||||
|
else:
|
||||||
|
Tout = [Tout]
|
||||||
|
|
||||||
|
if eager:
|
||||||
|
func = EagerFunc(func, Tout)
|
||||||
|
|
||||||
|
token = _py_funcs.insert(func)
|
||||||
|
# We tie the registered function's lifetime with the current default graph,
|
||||||
|
# i.e., when the current graph is destroyed, we remove its py funcs.
|
||||||
|
graph = ops.get_default_graph()
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
while isinstance(graph, function._FuncGraph):
|
||||||
|
# If the py_func was declared inside a _FuncGraph, its lifetime should be
|
||||||
|
# bound to that of the outer graph instead.
|
||||||
|
graph = graph._outer_graph
|
||||||
|
|
||||||
|
cleanup = CleanupFunc(token)
|
||||||
|
|
||||||
|
# TODO(zhifengc): Consider adding a Graph method to collect
|
||||||
|
# `cleanup` objects in one of its member.
|
||||||
|
if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
|
||||||
|
graph._cleanup_py_funcs_used_in_graph = []
|
||||||
|
|
||||||
|
# When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
|
||||||
|
# will be destroyed and their __del__ will remove the 'token' from
|
||||||
|
# the funcs registry.
|
||||||
|
graph._cleanup_py_funcs_used_in_graph.append(cleanup)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if eager:
|
||||||
|
result = gen_script_ops._eager_py_func(
|
||||||
|
input=inp, token=token, Tout=Tout, name=name)
|
||||||
|
else:
|
||||||
|
if stateful:
|
||||||
|
result = gen_script_ops._py_func(
|
||||||
|
input=inp, token=token, Tout=Tout, name=name)
|
||||||
|
else:
|
||||||
|
result = gen_script_ops._py_func_stateless(
|
||||||
|
input=inp, token=token, Tout=Tout, name=name)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
return result if is_list_or_tuple else result[0]
|
||||||
|
|
||||||
|
|
||||||
|
def eager_py_func(func, inp, Tout, name=None):
|
||||||
|
"""Wraps a python function into a TensorFlow op.
|
||||||
|
|
||||||
|
When the returned op is executed, `func` is invoked with eager execution
|
||||||
|
enabled. Inputs are Tensor objects and func must return None or objects
|
||||||
|
that may be converted to Tensor objects.
|
||||||
|
|
||||||
|
This function has the same limitations as `py_func` with respect to
|
||||||
|
serialization and distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: A Python function which accepts a list of `Tensor` objects
|
||||||
|
having element types that match the corresponding `tf.Tensor` objects
|
||||||
|
in `inp` and returns a list of `Tensor` objects (or a single
|
||||||
|
`Tensor`, or `None`) having element types that match the
|
||||||
|
corresponding values in `Tout`.
|
||||||
|
inp: A list of `Tensor` objects.
|
||||||
|
Tout: A list or tuple of tensorflow data types or a single tensorflow data
|
||||||
|
type if there is only one, indicating what `func` returns; an empty list
|
||||||
|
if no value is returned (i.e., if the return value is `None`).
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||||
|
if `func` returns None.
|
||||||
|
"""
|
||||||
|
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||||
|
|
||||||
|
|
||||||
def py_func(func, inp, Tout, stateful=True, name=None):
|
def py_func(func, inp, Tout, stateful=True, name=None):
|
||||||
"""Wraps a python function and uses it as a TensorFlow op.
|
"""Wraps a python function and uses it as a TensorFlow op.
|
||||||
|
|
||||||
@ -182,46 +296,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of `Tensor` or a single `Tensor` which `func` computes.
|
A list of `Tensor` or a single `Tensor` which `func` computes.
|
||||||
"""
|
"""
|
||||||
token = _py_funcs.insert(func)
|
return _internal_py_func(
|
||||||
# We tie the registered function's life-time with the current
|
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
|
||||||
# default graph. I.e., when the current graph is destroyed, we
|
|
||||||
# should remove its py funcs.
|
|
||||||
g = ops.get_default_graph()
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
while isinstance(g, function._FuncGraph):
|
|
||||||
# If the py_func was declared inside a _FuncGraph, its lifetime should be
|
|
||||||
# bound to that of the outer graph instead.
|
|
||||||
g = g._outer_graph
|
|
||||||
|
|
||||||
cleanup = CleanupFunc(token)
|
|
||||||
|
|
||||||
# TODO(zhifengc): Consider adding a Graph method to collect
|
|
||||||
# `cleanup` objects in one of its member.
|
|
||||||
if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
|
|
||||||
g._cleanup_py_funcs_used_in_graph = []
|
|
||||||
|
|
||||||
# When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
|
|
||||||
# will be destroyed and their __del__ will remove the 'token' from
|
|
||||||
# the funcs registry.
|
|
||||||
g._cleanup_py_funcs_used_in_graph.append(cleanup)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
if isinstance(Tout, (list, tuple)):
|
|
||||||
is_list_or_tuple = True
|
|
||||||
else:
|
|
||||||
Tout = [Tout]
|
|
||||||
is_list_or_tuple = False
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
if stateful:
|
|
||||||
result = gen_script_ops._py_func(
|
|
||||||
input=inp, token=token, Tout=Tout, name=name)
|
|
||||||
else:
|
|
||||||
result = gen_script_ops._py_func_stateless(
|
|
||||||
input=inp, token=token, Tout=Tout, name=name)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
return result if is_list_or_tuple else result[0]
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
|
||||||
|
# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
|
||||||
|
# eager attribute is not set, and otherwise, it should return the gradient.
|
||||||
ops.NotDifferentiable("PyFunc")
|
ops.NotDifferentiable("PyFunc")
|
||||||
ops.NotDifferentiable("PyFuncStateless")
|
ops.NotDifferentiable("PyFuncStateless")
|
||||||
|
Loading…
Reference in New Issue
Block a user