Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.
In particular, an EagerPyFunc op is added that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs. Because functions wrapped in an EagerPyFunc are executed eagerly, they can make use of TensorFlow operations. EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops. tfe.py_func will also make it easier to translate python functions with side-effects into defun-able code. PiperOrigin-RevId: 178303818
This commit is contained in:
parent
2d4c29cd6a
commit
f37380b064
@ -19,6 +19,7 @@ py_library(
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:numerics",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
|
@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||
@@list_devices
|
||||
@@num_gpus
|
||||
|
||||
@@py_func
|
||||
@@defun
|
||||
@@implicit_gradients
|
||||
@@implicit_value_and_gradients
|
||||
@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest
|
||||
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
|
||||
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
|
||||
from tensorflow.python.ops.variable_scope import EagerVariableStore
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
py_func = script_ops.eager_py_func
|
||||
defun = function.defun
|
||||
implicit_gradients = backprop.implicit_grad
|
||||
implicit_value_and_gradients = backprop.implicit_val_and_grad
|
||||
|
@ -0,0 +1,8 @@
|
||||
op {
|
||||
graph_op_name: "EagerPyFunc"
|
||||
summary: "Eagerly executes a python function to compute func(input)->output. The"
|
||||
description: <<END
|
||||
semantics of the input, output, and attributes are the same as those for
|
||||
PyFunc.
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "EagerPyFunc"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -51,4 +51,18 @@ REGISTER_OP("PyFuncStateless")
|
||||
A stateless version of PyFunc.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("EagerPyFunc")
|
||||
.Input("input: Tin")
|
||||
.Output("output: Tout")
|
||||
.Attr("token: string")
|
||||
.Attr("Tin: list(type) >= 0")
|
||||
.Attr("Tout: list(type) >=0")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Eagerly executes a python function to compute func(input)->output. The
|
||||
semantics of the input, output, and attributes are the same as those for
|
||||
PyFunc.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -280,10 +280,14 @@ cc_library(
|
||||
":ndarray_tensor_bridge",
|
||||
":numpy_lib",
|
||||
":py_util",
|
||||
":safe_ptr",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:script_ops_op_lib",
|
||||
"//tensorflow/python/eager:pywrap_tfe_lib",
|
||||
"//third_party/py/numpy:headers",
|
||||
"//util/python:python_headers",
|
||||
],
|
||||
|
@ -1645,6 +1645,8 @@ cuda_py_test(
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:script_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:function",
|
||||
],
|
||||
tags = ["no_windows"],
|
||||
)
|
||||
|
@ -23,82 +23,93 @@ from six.moves import queue
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class PyOpTest(test.TestCase):
|
||||
|
||||
def testBasic(self):
|
||||
|
||||
def my_func(x, y):
|
||||
def np_func(x, y):
|
||||
return np.sinh(x) + np.cosh(y)
|
||||
|
||||
# single type
|
||||
|
||||
def matmul(x, y):
|
||||
return math_ops.matmul(x, y)
|
||||
|
||||
|
||||
class PyFuncTest(test.TestCase):
|
||||
"""Encapsulates tests for py_func and eager_py_func."""
|
||||
|
||||
# ----- Tests for py_func -----
|
||||
def testSingleType(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1.0, dtypes.float32)
|
||||
y = constant_op.constant(2.0, dtypes.float32)
|
||||
z = script_ops.py_func(my_func, [x, y], dtypes.float32)
|
||||
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
|
||||
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
|
||||
self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
|
||||
|
||||
# scalar
|
||||
def testScalar(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1.0, dtypes.float32)
|
||||
y = constant_op.constant(2.0, dtypes.float32)
|
||||
z = script_ops.py_func(my_func, [x, y], [dtypes.float32])
|
||||
self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
|
||||
z = self.evaluate(
|
||||
script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
|
||||
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
|
||||
|
||||
# array
|
||||
def testArray(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant([1.0, 2.0], dtypes.float64)
|
||||
y = constant_op.constant([2.0, 3.0], dtypes.float64)
|
||||
z = script_ops.py_func(my_func, [x, y], [dtypes.float64])
|
||||
self.assertAllEqual(z[0].eval(),
|
||||
my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
|
||||
z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
|
||||
self.assertAllEqual(z[0],
|
||||
np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
|
||||
|
||||
# a bit exotic type (complex64)
|
||||
def testComplexType(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1 + 2j, dtypes.complex64)
|
||||
y = constant_op.constant(3 + 4j, dtypes.complex64)
|
||||
z, = script_ops.py_func(my_func, [x, y], [dtypes.complex64])
|
||||
self.assertAllClose(z.eval(), my_func(1 + 2j, 3 + 4j))
|
||||
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
|
||||
self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
|
||||
|
||||
# a bit excotic function (rfft)
|
||||
def testRFFT(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
|
||||
|
||||
def rfft(x):
|
||||
return np.fft.rfft(x).astype(np.complex64)
|
||||
|
||||
y, = script_ops.py_func(rfft, [x], [dtypes.complex64])
|
||||
self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
|
||||
y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
|
||||
self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
|
||||
|
||||
# returns a python literal.
|
||||
def testPythonLiteral(self):
|
||||
with self.test_session():
|
||||
|
||||
def literal(x):
|
||||
return 1.0 if x == 0.0 else 0.0
|
||||
return 1.0 if float(x) == 0.0 else 0.0
|
||||
|
||||
x = constant_op.constant(0.0, dtypes.float64)
|
||||
y, = script_ops.py_func(literal, [x], [dtypes.float64])
|
||||
self.assertAllClose(y.eval(), 1.0)
|
||||
y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
|
||||
self.assertAllClose(y, 1.0)
|
||||
|
||||
# returns a list
|
||||
def testList(self):
|
||||
with self.test_session():
|
||||
|
||||
def list_func(x):
|
||||
return [x, x + 1]
|
||||
|
||||
x = constant_op.constant(0.0, dtypes.float64)
|
||||
y, z = script_ops.py_func(list_func, [x], [dtypes.float64] * 2)
|
||||
self.assertAllClose(y.eval(), 0.0)
|
||||
self.assertAllClose(z.eval(), 1.0)
|
||||
y = self.evaluate(
|
||||
script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
|
||||
self.assertAllClose(y, [0.0, 1.0])
|
||||
|
||||
def testTuple(self):
|
||||
# returns a tuple
|
||||
with self.test_session():
|
||||
|
||||
@ -106,17 +117,17 @@ class PyOpTest(test.TestCase):
|
||||
return x, x + 1
|
||||
|
||||
x = constant_op.constant(0.0, dtypes.float64)
|
||||
y, z = script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2)
|
||||
self.assertAllClose(y.eval(), 0.0)
|
||||
self.assertAllClose(z.eval(), 1.0)
|
||||
y = self.evaluate(
|
||||
script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
|
||||
self.assertAllClose(y, [0.0, 1.0])
|
||||
|
||||
# returns a tuple, Tout and inp a tuple
|
||||
with self.test_session():
|
||||
x = constant_op.constant(0.0, dtypes.float64)
|
||||
y, z = script_ops.py_func(tuple_func, (x,), (dtypes.float64,
|
||||
dtypes.float64))
|
||||
self.assertAllClose(y.eval(), 0.0)
|
||||
self.assertAllClose(z.eval(), 1.0)
|
||||
y = self.evaluate(
|
||||
script_ops.py_func(tuple_func, (x,),
|
||||
(dtypes.float64, dtypes.float64)))
|
||||
self.assertAllClose(y, [0.0, 1.0])
|
||||
|
||||
def testStrings(self):
|
||||
|
||||
@ -128,10 +139,12 @@ class PyOpTest(test.TestCase):
|
||||
|
||||
with self.test_session():
|
||||
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
|
||||
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||
[dtypes.string])
|
||||
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
|
||||
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
|
||||
y = self.evaluate(
|
||||
script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||
dtypes.string))
|
||||
z = self.evaluate(
|
||||
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
|
||||
self.assertAllEqual(z, [b"hello there", b"hi there"])
|
||||
|
||||
def testStringsAreConvertedToBytes(self):
|
||||
|
||||
@ -143,10 +156,12 @@ class PyOpTest(test.TestCase):
|
||||
|
||||
with self.test_session():
|
||||
x = constant_op.constant(["hello", "hi"], dtypes.string)
|
||||
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||
[dtypes.string])
|
||||
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
|
||||
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
|
||||
y = self.evaluate(
|
||||
script_ops.py_func(read_fixed_length_numpy_strings, [],
|
||||
dtypes.string))
|
||||
z = self.evaluate(
|
||||
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
|
||||
self.assertAllEqual(z, [b"hello there", b"hi there"])
|
||||
|
||||
def testObjectArraysAreConvertedToBytes(self):
|
||||
|
||||
@ -186,16 +201,8 @@ class PyOpTest(test.TestCase):
|
||||
|
||||
def testNoInput(self):
|
||||
with self.test_session():
|
||||
x, = script_ops.py_func(lambda: 42.0, [], [dtypes.float64])
|
||||
self.assertAllClose(x.eval(), 42.0)
|
||||
|
||||
def testCleanup(self):
|
||||
for _ in xrange(1000):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
c = constant_op.constant([1.], dtypes.float32)
|
||||
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
|
||||
self.assertTrue(script_ops._py_funcs.size() < 100)
|
||||
x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
|
||||
self.assertAllClose(x, 42.0)
|
||||
|
||||
def testAlias(self):
|
||||
with self.test_session():
|
||||
@ -319,10 +326,10 @@ class PyOpTest(test.TestCase):
|
||||
def value(self):
|
||||
return self._value
|
||||
|
||||
with self.test_session() as sess:
|
||||
with self.test_session():
|
||||
s = State()
|
||||
op = s.increment(constant_op.constant(2, dtypes.int64))
|
||||
ret = sess.run(op)
|
||||
ret = self.evaluate(op)
|
||||
self.assertIsNone(ret)
|
||||
self.assertAllEqual([3], s.value)
|
||||
|
||||
@ -336,15 +343,24 @@ class PyOpTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
self.assertEqual(sess.run(f), [])
|
||||
|
||||
def _testExceptionHandling(self, py_exp, tf_exp):
|
||||
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
|
||||
|
||||
def raise_exception():
|
||||
raise py_exp("blah") # pylint: disable=not-callable
|
||||
|
||||
f = script_ops.py_func(raise_exception, [], [])
|
||||
with self.test_session() as sess:
|
||||
if eager:
|
||||
if context.in_eager_mode():
|
||||
with self.assertRaisesRegexp(tf_exp, "blah"):
|
||||
sess.run(f)
|
||||
f = script_ops.eager_py_func(raise_exception, [], [])
|
||||
return
|
||||
else:
|
||||
f = script_ops.eager_py_func(raise_exception, [], [])
|
||||
else:
|
||||
f = script_ops.py_func(raise_exception, [], [])
|
||||
|
||||
with self.test_session():
|
||||
with self.assertRaisesRegexp(tf_exp, "blah"):
|
||||
self.evaluate(f)
|
||||
|
||||
def testExceptionHandling(self):
|
||||
self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
|
||||
@ -358,6 +374,89 @@ class PyOpTest(test.TestCase):
|
||||
|
||||
self._testExceptionHandling(WeirdError, errors.UnknownError)
|
||||
|
||||
# ----- Tests shared by py_func and eager_py_func -----
|
||||
def testCleanup(self):
|
||||
for _ in xrange(1000):
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
c = constant_op.constant([1.], dtypes.float32)
|
||||
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
|
||||
_ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
|
||||
self.assertTrue(script_ops._py_funcs.size() < 100)
|
||||
|
||||
# ----- Tests for eager_py_func -----
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerSingleOutputInt32(self):
|
||||
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
|
||||
with self.test_session():
|
||||
ret = self.evaluate(output)
|
||||
self.assertAllEqual(ret, [[3], [3], [3]])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerSingleOutputFloat32(self):
|
||||
a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||
x = array_ops.ones((3, 1), dtype=dtypes.float32)
|
||||
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
|
||||
with self.test_session():
|
||||
ret = self.evaluate(output)
|
||||
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerArrayOutput(self):
|
||||
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||
output = script_ops.eager_py_func(
|
||||
lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
|
||||
|
||||
with self.test_session():
|
||||
ret = self.evaluate(output)
|
||||
self.assertAllEqual(ret, [[[3], [3], [3]]])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerReturnNone(self):
|
||||
|
||||
def no_return_value():
|
||||
return
|
||||
|
||||
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
|
||||
ret = self.evaluate(output)
|
||||
if context.in_eager_mode():
|
||||
self.assertEquals(len(ret), 0)
|
||||
else:
|
||||
self.assertIsNone(ret)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerPyFuncInDefun(self):
|
||||
|
||||
def wrapper():
|
||||
a = array_ops.ones((3, 3), dtype=dtypes.int32)
|
||||
x = array_ops.ones((3, 1), dtype=dtypes.int32)
|
||||
return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
|
||||
|
||||
wrapped = function.defun(wrapper)
|
||||
ret = self.evaluate(wrapped())
|
||||
self.assertAllEqual(ret, [[3], [3], [3]])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testEagerExceptionHandling(self):
|
||||
self._testExceptionHandling(
|
||||
ValueError, errors.InvalidArgumentError, eager=True)
|
||||
self._testExceptionHandling(
|
||||
TypeError, errors.InvalidArgumentError, eager=True)
|
||||
self._testExceptionHandling(
|
||||
StopIteration, errors.OutOfRangeError, eager=True)
|
||||
self._testExceptionHandling(
|
||||
MemoryError, errors.ResourceExhaustedError, eager=True)
|
||||
self._testExceptionHandling(
|
||||
NotImplementedError, errors.UnimplementedError, eager=True)
|
||||
|
||||
class WeirdError(Exception):
|
||||
pass
|
||||
|
||||
self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
#include <array>
|
||||
|
||||
#include "numpy/arrayobject.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -25,8 +27,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
|
||||
#include "tensorflow/python/lib/core/py_util.h"
|
||||
#include "tensorflow/python/lib/core/safe_ptr.h"
|
||||
#include <Python.h>
|
||||
|
||||
namespace tensorflow {
|
||||
@ -48,6 +52,9 @@ struct PyCall {
|
||||
// with this "token".
|
||||
string token;
|
||||
|
||||
// True if the call is associated with an EagerPyFunc.
|
||||
bool eager;
|
||||
|
||||
// Inputs and outputs of this function invocation.
|
||||
std::vector<Tensor> ins;
|
||||
std::vector<Tensor> out;
|
||||
@ -55,19 +62,26 @@ struct PyCall {
|
||||
|
||||
// Givens the 'call', prepares the token and inputs as a python tuple
|
||||
// that is appropriate for calling the trampoline.
|
||||
Status MakeArgTuple(PyCall* call, PyObject** tuple) {
|
||||
Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
|
||||
int64 n = call->ins.size();
|
||||
PyObject* lst = PyList_New(n);
|
||||
CHECK(lst);
|
||||
for (int64 i = 0; i < n; ++i) {
|
||||
PyObject* arg = nullptr;
|
||||
const Tensor& t = call->ins[i];
|
||||
PyObject* a = nullptr;
|
||||
Status s = ConvertTensorToNdarray(t, &a);
|
||||
if (call->eager) {
|
||||
arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
|
||||
if (arg == nullptr) {
|
||||
return errors::Internal("Unable to procure EagerTensor from Tensor.");
|
||||
}
|
||||
} else {
|
||||
Status s = ConvertTensorToNdarray(t, &arg);
|
||||
if (!s.ok()) {
|
||||
Py_DECREF(lst);
|
||||
return s;
|
||||
}
|
||||
PyList_SetItem(lst, i, a);
|
||||
}
|
||||
PyList_SetItem(lst, i, arg);
|
||||
}
|
||||
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
|
||||
CHECK(*tuple);
|
||||
@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
|
||||
return item == Py_None;
|
||||
}
|
||||
|
||||
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
|
||||
Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
|
||||
Tensor* output_tensor,
|
||||
TF_Status* tf_status) {
|
||||
// TODO(akshayka): Lift the restriction requiring output tensors to
|
||||
// lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
|
||||
// tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
|
||||
*output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
|
||||
EagerTensor_Handle(eager_tensor), tf_status);
|
||||
return StatusFromTF_Status(tf_status);
|
||||
}
|
||||
|
||||
// Calls the registered py function through the trampoline.
|
||||
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
*out_log_on_error = true;
|
||||
@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process the return values and converts them to tf Tensors.
|
||||
// Process the return values and convert them to TF Tensors.
|
||||
Status s;
|
||||
if (PyList_Check(result)) {
|
||||
// 'result' is a list.
|
||||
call->out.clear();
|
||||
for (int i = 0; i < PyList_Size(result); ++i) {
|
||||
Tensor t;
|
||||
if (call->eager) {
|
||||
auto tf_status = tensorflow::make_safe(TF_NewStatus());
|
||||
s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
|
||||
tf_status.get());
|
||||
} else {
|
||||
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
break;
|
||||
}
|
||||
call->out.push_back(t);
|
||||
}
|
||||
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
|
||||
DCHECK(call->eager);
|
||||
Tensor t;
|
||||
if (result != Py_None) {
|
||||
auto tf_status = tensorflow::make_safe(TF_NewStatus());
|
||||
s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
|
||||
if (s.ok()) {
|
||||
call->out.push_back(t);
|
||||
}
|
||||
}
|
||||
} else if (PyArray_Check(result)) {
|
||||
// 'result' is a single ndarray.
|
||||
DCHECK(!call->eager);
|
||||
if (!IsSingleNone(result)) {
|
||||
Tensor t;
|
||||
s = ConvertNdarrayToTensor(result, &t);
|
||||
@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
|
||||
public:
|
||||
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
|
||||
eager_ = type_string() == "EagerPyFunc";
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
PyCall call;
|
||||
call.token = token_;
|
||||
call.eager = eager_;
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
call.ins.push_back(ctx->input(i));
|
||||
}
|
||||
@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
|
||||
private:
|
||||
string token_;
|
||||
|
||||
// True if and only if this op should execute the python function eagerly,
|
||||
// i.e., if and only if the eager attribute is set.
|
||||
bool eager_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -24,21 +24,27 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Called by py code on initialization.
|
||||
// Called by python code on initialization.
|
||||
//
|
||||
// "trampoline" must represent a python function which has the
|
||||
// following signature:
|
||||
// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
|
||||
// (string, list(ndarray)) | (string, list(EagerTensor)) ->
|
||||
// ndarray | list(ndarray) | python scalar |
|
||||
// EagerTensor | list(EagerTensor) | None
|
||||
//
|
||||
// The trampoline takes two arguments, the first is a string token
|
||||
// used by the python frontend's dispatching logic; the second is a
|
||||
// list of numpy ndarrays.
|
||||
// list of numpy ndarrays or EagerTensor objects. It can return a
|
||||
// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
|
||||
// EagerTensor, a list of EagerTensors, or None.
|
||||
//
|
||||
// The trampoline can return a single numpy ndarray, a list of numpy
|
||||
// ndarrays, or a simply python scalar. The C++ runtime converts them,
|
||||
// if supported, back to Tensor objects.
|
||||
// PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
|
||||
// inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
|
||||
// of EagerTensors, or None.
|
||||
//
|
||||
// This is called by script_ops.py during its module initialization.
|
||||
// The C++ runtime converts outputs back to Tensor objects.
|
||||
//
|
||||
// This function is called by script_ops.py during its module initialization.
|
||||
//
|
||||
// TODO(zhifengc): Support distributed runtime.
|
||||
void InitializePyTrampoline(PyObject* trampoline);
|
||||
|
@ -341,6 +341,7 @@ TruncatedNormal
|
||||
# script_ops
|
||||
PyFunc
|
||||
PyFuncStateless
|
||||
EagerPyFunc
|
||||
|
||||
# sdca_ops
|
||||
|
||||
|
@ -29,11 +29,41 @@ import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_script_ops
|
||||
|
||||
|
||||
class EagerFunc(object):
|
||||
"""A wrapper for a function owned by an EagerPyFunc."""
|
||||
|
||||
def __init__(self, func, Tout):
|
||||
"""Constructs an EagerFunc.
|
||||
|
||||
Args:
|
||||
func: The function to wrap.
|
||||
Tout: A list of datatypes for the output; an empty list if the output is
|
||||
None.
|
||||
"""
|
||||
self._func = func
|
||||
self._out_dtypes = Tout
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Passes args, kwargs to `self._func`, which is executed eagerly."""
|
||||
with context.eager_mode():
|
||||
ret = self._func(*args, **kwargs)
|
||||
if isinstance(ret, (tuple, list)):
|
||||
return [
|
||||
ops.convert_to_tensor(x, dtype=dtype)
|
||||
for (x, dtype) in zip(ret, self._out_dtypes)
|
||||
]
|
||||
elif ret is None:
|
||||
return ret
|
||||
else:
|
||||
return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
|
||||
|
||||
|
||||
class FuncRegistry(object):
|
||||
"""A helper class to keep track of registered py functions.
|
||||
|
||||
@ -91,6 +121,10 @@ class FuncRegistry(object):
|
||||
if func is None:
|
||||
raise ValueError("callback %s is not found" % token)
|
||||
ret = func(*args)
|
||||
|
||||
if isinstance(func, EagerFunc):
|
||||
return ret
|
||||
else:
|
||||
# Strings seem to lead to a memory leak here if they're not wrapped in a
|
||||
# list.
|
||||
if isinstance(ret, six.binary_type):
|
||||
@ -129,6 +163,86 @@ class CleanupFunc(object):
|
||||
_py_funcs.remove(self._token)
|
||||
|
||||
|
||||
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
|
||||
"""See documentation for py_func and eager_py_func."""
|
||||
|
||||
is_list_or_tuple = False
|
||||
if isinstance(Tout, (list, tuple)):
|
||||
is_list_or_tuple = True
|
||||
else:
|
||||
Tout = [Tout]
|
||||
|
||||
if eager:
|
||||
func = EagerFunc(func, Tout)
|
||||
|
||||
token = _py_funcs.insert(func)
|
||||
# We tie the registered function's lifetime with the current default graph,
|
||||
# i.e., when the current graph is destroyed, we remove its py funcs.
|
||||
graph = ops.get_default_graph()
|
||||
|
||||
# pylint: disable=protected-access
|
||||
while isinstance(graph, function._FuncGraph):
|
||||
# If the py_func was declared inside a _FuncGraph, its lifetime should be
|
||||
# bound to that of the outer graph instead.
|
||||
graph = graph._outer_graph
|
||||
|
||||
cleanup = CleanupFunc(token)
|
||||
|
||||
# TODO(zhifengc): Consider adding a Graph method to collect
|
||||
# `cleanup` objects in one of its member.
|
||||
if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
|
||||
graph._cleanup_py_funcs_used_in_graph = []
|
||||
|
||||
# When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
|
||||
# will be destroyed and their __del__ will remove the 'token' from
|
||||
# the funcs registry.
|
||||
graph._cleanup_py_funcs_used_in_graph.append(cleanup)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if eager:
|
||||
result = gen_script_ops._eager_py_func(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
else:
|
||||
if stateful:
|
||||
result = gen_script_ops._py_func(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
else:
|
||||
result = gen_script_ops._py_func_stateless(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
# pylint: enable=protected-access
|
||||
return result if is_list_or_tuple else result[0]
|
||||
|
||||
|
||||
def eager_py_func(func, inp, Tout, name=None):
|
||||
"""Wraps a python function into a TensorFlow op.
|
||||
|
||||
When the returned op is executed, `func` is invoked with eager execution
|
||||
enabled. Inputs are Tensor objects and func must return None or objects
|
||||
that may be converted to Tensor objects.
|
||||
|
||||
This function has the same limitations as `py_func` with respect to
|
||||
serialization and distribution.
|
||||
|
||||
Args:
|
||||
func: A Python function which accepts a list of `Tensor` objects
|
||||
having element types that match the corresponding `tf.Tensor` objects
|
||||
in `inp` and returns a list of `Tensor` objects (or a single
|
||||
`Tensor`, or `None`) having element types that match the
|
||||
corresponding values in `Tout`.
|
||||
inp: A list of `Tensor` objects.
|
||||
Tout: A list or tuple of tensorflow data types or a single tensorflow data
|
||||
type if there is only one, indicating what `func` returns; an empty list
|
||||
if no value is returned (i.e., if the return value is `None`).
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||
if `func` returns None.
|
||||
"""
|
||||
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||
|
||||
|
||||
def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
"""Wraps a python function and uses it as a TensorFlow op.
|
||||
|
||||
@ -182,46 +296,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
|
||||
Returns:
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes.
|
||||
"""
|
||||
token = _py_funcs.insert(func)
|
||||
# We tie the registered function's life-time with the current
|
||||
# default graph. I.e., when the current graph is destroyed, we
|
||||
# should remove its py funcs.
|
||||
g = ops.get_default_graph()
|
||||
|
||||
# pylint: disable=protected-access
|
||||
while isinstance(g, function._FuncGraph):
|
||||
# If the py_func was declared inside a _FuncGraph, its lifetime should be
|
||||
# bound to that of the outer graph instead.
|
||||
g = g._outer_graph
|
||||
|
||||
cleanup = CleanupFunc(token)
|
||||
|
||||
# TODO(zhifengc): Consider adding a Graph method to collect
|
||||
# `cleanup` objects in one of its member.
|
||||
if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
|
||||
g._cleanup_py_funcs_used_in_graph = []
|
||||
|
||||
# When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
|
||||
# will be destroyed and their __del__ will remove the 'token' from
|
||||
# the funcs registry.
|
||||
g._cleanup_py_funcs_used_in_graph.append(cleanup)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
if isinstance(Tout, (list, tuple)):
|
||||
is_list_or_tuple = True
|
||||
else:
|
||||
Tout = [Tout]
|
||||
is_list_or_tuple = False
|
||||
# pylint: disable=protected-access
|
||||
if stateful:
|
||||
result = gen_script_ops._py_func(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
else:
|
||||
result = gen_script_ops._py_func_stateless(
|
||||
input=inp, token=token, Tout=Tout, name=name)
|
||||
# pylint: enable=protected-access
|
||||
return result if is_list_or_tuple else result[0]
|
||||
return _internal_py_func(
|
||||
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
|
||||
|
||||
|
||||
# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
|
||||
# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
|
||||
# eager attribute is not set, and otherwise, it should return the gradient.
|
||||
ops.NotDifferentiable("PyFunc")
|
||||
ops.NotDifferentiable("PyFuncStateless")
|
||||
|
Loading…
Reference in New Issue
Block a user