Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.

In particular, an EagerPyFunc op is added that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs. Because functions wrapped in an EagerPyFunc are executed eagerly, they can make use of TensorFlow operations.

EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops.

tfe.py_func will also make it easier to translate python functions with side-effects into defun-able code.

PiperOrigin-RevId: 178303818
This commit is contained in:
Akshay Agrawal 2017-12-07 15:18:46 -08:00 committed by TensorFlower Gardener
parent 2d4c29cd6a
commit f37380b064
12 changed files with 399 additions and 127 deletions

View File

@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:framework_test_lib", "//tensorflow/python:framework_test_lib",
"//tensorflow/python:numerics", "//tensorflow/python:numerics",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python/eager:backprop", "//tensorflow/python/eager:backprop",

View File

@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@list_devices @@list_devices
@@num_gpus @@num_gpus
@@py_func
@@defun @@defun
@@implicit_gradients @@implicit_gradients
@@implicit_value_and_gradients @@implicit_value_and_gradients
@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops.variable_scope import EagerVariableStore
from tensorflow.python.ops import script_ops
from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.util.all_util import remove_undocumented
py_func = script_ops.eager_py_func
defun = function.defun defun = function.defun
implicit_gradients = backprop.implicit_grad implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad implicit_value_and_gradients = backprop.implicit_val_and_grad

View File

@ -0,0 +1,8 @@
op {
graph_op_name: "EagerPyFunc"
summary: "Eagerly executes a python function to compute func(input)->output. The"
description: <<END
semantics of the input, output, and attributes are the same as those for
PyFunc.
END
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "EagerPyFunc"
visibility: HIDDEN
}

View File

@ -51,4 +51,18 @@ REGISTER_OP("PyFuncStateless")
A stateless version of PyFunc. A stateless version of PyFunc.
)doc"); )doc");
REGISTER_OP("EagerPyFunc")
.Input("input: Tin")
.Output("output: Tout")
.Attr("token: string")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >=0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Eagerly executes a python function to compute func(input)->output. The
semantics of the input, output, and attributes are the same as those for
PyFunc.
)doc");
} // namespace tensorflow } // namespace tensorflow

View File

@ -280,10 +280,14 @@ cc_library(
":ndarray_tensor_bridge", ":ndarray_tensor_bridge",
":numpy_lib", ":numpy_lib",
":py_util", ":py_util",
":safe_ptr",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:script_ops_op_lib", "//tensorflow/core:script_ops_op_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/py/numpy:headers", "//third_party/py/numpy:headers",
"//util/python:python_headers", "//util/python:python_headers",
], ],

View File

@ -1645,6 +1645,8 @@ cuda_py_test(
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:script_ops", "//tensorflow/python:script_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
], ],
tags = ["no_windows"], tags = ["no_windows"],
) )

View File

@ -23,82 +23,93 @@ from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session as session_lib from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class PyOpTest(test.TestCase): def np_func(x, y):
return np.sinh(x) + np.cosh(y)
def testBasic(self):
def my_func(x, y): def matmul(x, y):
return np.sinh(x) + np.cosh(y) return math_ops.matmul(x, y)
# single type
class PyFuncTest(test.TestCase):
"""Encapsulates tests for py_func and eager_py_func."""
# ----- Tests for py_func -----
def testSingleType(self):
with self.test_session(): with self.test_session():
x = constant_op.constant(1.0, dtypes.float32) x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32)
z = script_ops.py_func(my_func, [x, y], dtypes.float32) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32)) self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
# scalar def testScalar(self):
with self.test_session(): with self.test_session():
x = constant_op.constant(1.0, dtypes.float32) x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32)
z = script_ops.py_func(my_func, [x, y], [dtypes.float32]) z = self.evaluate(
self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32)) script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
# array def testArray(self):
with self.test_session(): with self.test_session():
x = constant_op.constant([1.0, 2.0], dtypes.float64) x = constant_op.constant([1.0, 2.0], dtypes.float64)
y = constant_op.constant([2.0, 3.0], dtypes.float64) y = constant_op.constant([2.0, 3.0], dtypes.float64)
z = script_ops.py_func(my_func, [x, y], [dtypes.float64]) z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
self.assertAllEqual(z[0].eval(), self.assertAllEqual(z[0],
my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64)) np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
# a bit exotic type (complex64) def testComplexType(self):
with self.test_session(): with self.test_session():
x = constant_op.constant(1 + 2j, dtypes.complex64) x = constant_op.constant(1 + 2j, dtypes.complex64)
y = constant_op.constant(3 + 4j, dtypes.complex64) y = constant_op.constant(3 + 4j, dtypes.complex64)
z, = script_ops.py_func(my_func, [x, y], [dtypes.complex64]) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
self.assertAllClose(z.eval(), my_func(1 + 2j, 3 + 4j)) self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
# a bit excotic function (rfft) def testRFFT(self):
with self.test_session(): with self.test_session():
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32) x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
def rfft(x): def rfft(x):
return np.fft.rfft(x).astype(np.complex64) return np.fft.rfft(x).astype(np.complex64)
y, = script_ops.py_func(rfft, [x], [dtypes.complex64]) y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.])) self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
# returns a python literal. def testPythonLiteral(self):
with self.test_session(): with self.test_session():
def literal(x): def literal(x):
return 1.0 if x == 0.0 else 0.0 return 1.0 if float(x) == 0.0 else 0.0
x = constant_op.constant(0.0, dtypes.float64) x = constant_op.constant(0.0, dtypes.float64)
y, = script_ops.py_func(literal, [x], [dtypes.float64]) y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
self.assertAllClose(y.eval(), 1.0) self.assertAllClose(y, 1.0)
# returns a list def testList(self):
with self.test_session(): with self.test_session():
def list_func(x): def list_func(x):
return [x, x + 1] return [x, x + 1]
x = constant_op.constant(0.0, dtypes.float64) x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(list_func, [x], [dtypes.float64] * 2) y = self.evaluate(
self.assertAllClose(y.eval(), 0.0) script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
self.assertAllClose(z.eval(), 1.0) self.assertAllClose(y, [0.0, 1.0])
def testTuple(self):
# returns a tuple # returns a tuple
with self.test_session(): with self.test_session():
@ -106,17 +117,17 @@ class PyOpTest(test.TestCase):
return x, x + 1 return x, x + 1
x = constant_op.constant(0.0, dtypes.float64) x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2) y = self.evaluate(
self.assertAllClose(y.eval(), 0.0) script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
self.assertAllClose(z.eval(), 1.0) self.assertAllClose(y, [0.0, 1.0])
# returns a tuple, Tout and inp a tuple # returns a tuple, Tout and inp a tuple
with self.test_session(): with self.test_session():
x = constant_op.constant(0.0, dtypes.float64) x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(tuple_func, (x,), (dtypes.float64, y = self.evaluate(
dtypes.float64)) script_ops.py_func(tuple_func, (x,),
self.assertAllClose(y.eval(), 0.0) (dtypes.float64, dtypes.float64)))
self.assertAllClose(z.eval(), 1.0) self.assertAllClose(y, [0.0, 1.0])
def testStrings(self): def testStrings(self):
@ -128,10 +139,12 @@ class PyOpTest(test.TestCase):
with self.test_session(): with self.test_session():
x = constant_op.constant([b"hello", b"hi"], dtypes.string) x = constant_op.constant([b"hello", b"hi"], dtypes.string)
y, = script_ops.py_func(read_fixed_length_numpy_strings, [], y = self.evaluate(
[dtypes.string]) script_ops.py_func(read_fixed_length_numpy_strings, [],
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string]) dtypes.string))
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"]) z = self.evaluate(
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
self.assertAllEqual(z, [b"hello there", b"hi there"])
def testStringsAreConvertedToBytes(self): def testStringsAreConvertedToBytes(self):
@ -143,10 +156,12 @@ class PyOpTest(test.TestCase):
with self.test_session(): with self.test_session():
x = constant_op.constant(["hello", "hi"], dtypes.string) x = constant_op.constant(["hello", "hi"], dtypes.string)
y, = script_ops.py_func(read_fixed_length_numpy_strings, [], y = self.evaluate(
[dtypes.string]) script_ops.py_func(read_fixed_length_numpy_strings, [],
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string]) dtypes.string))
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"]) z = self.evaluate(
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
self.assertAllEqual(z, [b"hello there", b"hi there"])
def testObjectArraysAreConvertedToBytes(self): def testObjectArraysAreConvertedToBytes(self):
@ -186,16 +201,8 @@ class PyOpTest(test.TestCase):
def testNoInput(self): def testNoInput(self):
with self.test_session(): with self.test_session():
x, = script_ops.py_func(lambda: 42.0, [], [dtypes.float64]) x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
self.assertAllClose(x.eval(), 42.0) self.assertAllClose(x, 42.0)
def testCleanup(self):
for _ in xrange(1000):
g = ops.Graph()
with g.as_default():
c = constant_op.constant([1.], dtypes.float32)
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
self.assertTrue(script_ops._py_funcs.size() < 100)
def testAlias(self): def testAlias(self):
with self.test_session(): with self.test_session():
@ -242,8 +249,8 @@ class PyOpTest(test.TestCase):
# Create a numpy array aliasing a tensor and a tensor aliasing this array # Create a numpy array aliasing a tensor and a tensor aliasing this array
z, = script_ops.py_func(ident, [p], [dtypes.float32]) z, = script_ops.py_func(ident, [p], [dtypes.float32])
z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0] z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0]
# above instead of using its memory as the return value of # above instead of using its memory as the return value of
# session.run # session.run
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]})) self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
def testStateful(self): def testStateful(self):
@ -319,10 +326,10 @@ class PyOpTest(test.TestCase):
def value(self): def value(self):
return self._value return self._value
with self.test_session() as sess: with self.test_session():
s = State() s = State()
op = s.increment(constant_op.constant(2, dtypes.int64)) op = s.increment(constant_op.constant(2, dtypes.int64))
ret = sess.run(op) ret = self.evaluate(op)
self.assertIsNone(ret) self.assertIsNone(ret)
self.assertAllEqual([3], s.value) self.assertAllEqual([3], s.value)
@ -336,15 +343,24 @@ class PyOpTest(test.TestCase):
with self.test_session() as sess: with self.test_session() as sess:
self.assertEqual(sess.run(f), []) self.assertEqual(sess.run(f), [])
def _testExceptionHandling(self, py_exp, tf_exp): def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
def raise_exception(): def raise_exception():
raise py_exp("blah") # pylint: disable=not-callable raise py_exp("blah") # pylint: disable=not-callable
f = script_ops.py_func(raise_exception, [], []) if eager:
with self.test_session() as sess: if context.in_eager_mode():
with self.assertRaisesRegexp(tf_exp, "blah"):
f = script_ops.eager_py_func(raise_exception, [], [])
return
else:
f = script_ops.eager_py_func(raise_exception, [], [])
else:
f = script_ops.py_func(raise_exception, [], [])
with self.test_session():
with self.assertRaisesRegexp(tf_exp, "blah"): with self.assertRaisesRegexp(tf_exp, "blah"):
sess.run(f) self.evaluate(f)
def testExceptionHandling(self): def testExceptionHandling(self):
self._testExceptionHandling(ValueError, errors.InvalidArgumentError) self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
@ -358,6 +374,89 @@ class PyOpTest(test.TestCase):
self._testExceptionHandling(WeirdError, errors.UnknownError) self._testExceptionHandling(WeirdError, errors.UnknownError)
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
for _ in xrange(1000):
g = ops.Graph()
with g.as_default():
c = constant_op.constant([1.], dtypes.float32)
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
self.assertTrue(script_ops._py_funcs.size() < 100)
# ----- Tests for eager_py_func -----
@test_util.run_in_graph_and_eager_modes()
def testEagerSingleOutputInt32(self):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
with self.test_session():
ret = self.evaluate(output)
self.assertAllEqual(ret, [[3], [3], [3]])
@test_util.run_in_graph_and_eager_modes()
def testEagerSingleOutputFloat32(self):
a = array_ops.ones((3, 3), dtype=dtypes.float32)
x = array_ops.ones((3, 1), dtype=dtypes.float32)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
with self.test_session():
ret = self.evaluate(output)
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
@test_util.run_in_graph_and_eager_modes()
def testEagerArrayOutput(self):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
output = script_ops.eager_py_func(
lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
with self.test_session():
ret = self.evaluate(output)
self.assertAllEqual(ret, [[[3], [3], [3]]])
@test_util.run_in_graph_and_eager_modes()
def testEagerReturnNone(self):
def no_return_value():
return
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
ret = self.evaluate(output)
if context.in_eager_mode():
self.assertEquals(len(ret), 0)
else:
self.assertIsNone(ret)
@test_util.run_in_graph_and_eager_modes()
def testEagerPyFuncInDefun(self):
def wrapper():
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
wrapped = function.defun(wrapper)
ret = self.evaluate(wrapped())
self.assertAllEqual(ret, [[3], [3], [3]])
@test_util.run_in_graph_and_eager_modes()
def testEagerExceptionHandling(self):
self._testExceptionHandling(
ValueError, errors.InvalidArgumentError, eager=True)
self._testExceptionHandling(
TypeError, errors.InvalidArgumentError, eager=True)
self._testExceptionHandling(
StopIteration, errors.OutOfRangeError, eager=True)
self._testExceptionHandling(
MemoryError, errors.ResourceExhaustedError, eager=True)
self._testExceptionHandling(
NotImplementedError, errors.UnimplementedError, eager=True)
class WeirdError(Exception):
pass
self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <array> #include <array>
#include "numpy/arrayobject.h" #include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
@ -25,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h" #include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include <Python.h> #include <Python.h>
namespace tensorflow { namespace tensorflow {
@ -48,6 +52,9 @@ struct PyCall {
// with this "token". // with this "token".
string token; string token;
// True if the call is associated with an EagerPyFunc.
bool eager;
// Inputs and outputs of this function invocation. // Inputs and outputs of this function invocation.
std::vector<Tensor> ins; std::vector<Tensor> ins;
std::vector<Tensor> out; std::vector<Tensor> out;
@ -55,19 +62,26 @@ struct PyCall {
// Givens the 'call', prepares the token and inputs as a python tuple // Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline. // that is appropriate for calling the trampoline.
Status MakeArgTuple(PyCall* call, PyObject** tuple) { Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
int64 n = call->ins.size(); int64 n = call->ins.size();
PyObject* lst = PyList_New(n); PyObject* lst = PyList_New(n);
CHECK(lst); CHECK(lst);
for (int64 i = 0; i < n; ++i) { for (int64 i = 0; i < n; ++i) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i]; const Tensor& t = call->ins[i];
PyObject* a = nullptr; if (call->eager) {
Status s = ConvertTensorToNdarray(t, &a); arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
if (!s.ok()) { if (arg == nullptr) {
Py_DECREF(lst); return errors::Internal("Unable to procure EagerTensor from Tensor.");
return s; }
} else {
Status s = ConvertTensorToNdarray(t, &arg);
if (!s.ok()) {
Py_DECREF(lst);
return s;
}
} }
PyList_SetItem(lst, i, a); PyList_SetItem(lst, i, arg);
} }
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst); *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
CHECK(*tuple); CHECK(*tuple);
@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None; return item == Py_None;
} }
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
Tensor* output_tensor,
TF_Status* tf_status) {
// TODO(akshayka): Lift the restriction requiring output tensors to
// lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
// tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
*output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
EagerTensor_Handle(eager_tensor), tf_status);
return StatusFromTF_Status(tf_status);
}
// Calls the registered py function through the trampoline. // Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true; *out_log_on_error = true;
@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
} }
} }
// Process the return values and converts them to tf Tensors. // Process the return values and convert them to TF Tensors.
Status s; Status s;
if (PyList_Check(result)) { if (PyList_Check(result)) {
// 'result' is a list.
call->out.clear(); call->out.clear();
for (int i = 0; i < PyList_Size(result); ++i) { for (int i = 0; i < PyList_Size(result); ++i) {
Tensor t; Tensor t;
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t); if (call->eager) {
auto tf_status = tensorflow::make_safe(TF_NewStatus());
s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
tf_status.get());
} else {
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
}
if (!s.ok()) { if (!s.ok()) {
break; break;
} }
call->out.push_back(t); call->out.push_back(t);
} }
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
DCHECK(call->eager);
Tensor t;
if (result != Py_None) {
auto tf_status = tensorflow::make_safe(TF_NewStatus());
s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
if (s.ok()) {
call->out.push_back(t);
}
}
} else if (PyArray_Check(result)) { } else if (PyArray_Check(result)) {
// 'result' is a single ndarray. DCHECK(!call->eager);
if (!IsSingleNone(result)) { if (!IsSingleNone(result)) {
Tensor t; Tensor t;
s = ConvertNdarrayToTensor(result, &t); s = ConvertNdarrayToTensor(result, &t);
@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
public: public:
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) { explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc";
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
PyCall call; PyCall call;
call.token = token_; call.token = token_;
call.eager = eager_;
for (int i = 0; i < ctx->num_inputs(); ++i) { for (int i = 0; i < ctx->num_inputs(); ++i) {
call.ins.push_back(ctx->input(i)); call.ins.push_back(ctx->input(i));
} }
@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
private: private:
string token_; string token_;
// True if and only if this op should execute the python function eagerly,
// i.e., if and only if the eager attribute is set.
bool eager_;
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp); TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
}; };
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp); REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -24,21 +24,27 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
// Called by py code on initialization. // Called by python code on initialization.
// //
// "trampoline" must represent a python function which has the // "trampoline" must represent a python function which has the
// following signature: // following signature:
// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar // (string, list(ndarray)) | (string, list(EagerTensor)) ->
// ndarray | list(ndarray) | python scalar |
// EagerTensor | list(EagerTensor) | None
// //
// The trampoline takes two arguments, the first is a string token // The trampoline takes two arguments, the first is a string token
// used by the python frontend's dispatching logic; the second is a // used by the python frontend's dispatching logic; the second is a
// list of numpy ndarrays. // list of numpy ndarrays or EagerTensor objects. It can return a
// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
// EagerTensor, a list of EagerTensors, or None.
// //
// The trampoline can return a single numpy ndarray, a list of numpy // PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
// ndarrays, or a simply python scalar. The C++ runtime converts them, // inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
// if supported, back to Tensor objects. // of EagerTensors, or None.
// //
// This is called by script_ops.py during its module initialization. // The C++ runtime converts outputs back to Tensor objects.
//
// This function is called by script_ops.py during its module initialization.
// //
// TODO(zhifengc): Support distributed runtime. // TODO(zhifengc): Support distributed runtime.
void InitializePyTrampoline(PyObject* trampoline); void InitializePyTrampoline(PyObject* trampoline);

View File

@ -341,6 +341,7 @@ TruncatedNormal
# script_ops # script_ops
PyFunc PyFunc
PyFuncStateless PyFuncStateless
EagerPyFunc
# sdca_ops # sdca_ops

View File

@ -29,11 +29,41 @@ import numpy as np
import six import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import function from tensorflow.python.framework import function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import gen_script_ops
class EagerFunc(object):
"""A wrapper for a function owned by an EagerPyFunc."""
def __init__(self, func, Tout):
"""Constructs an EagerFunc.
Args:
func: The function to wrap.
Tout: A list of datatypes for the output; an empty list if the output is
None.
"""
self._func = func
self._out_dtypes = Tout
def __call__(self, *args, **kwargs):
"""Passes args, kwargs to `self._func`, which is executed eagerly."""
with context.eager_mode():
ret = self._func(*args, **kwargs)
if isinstance(ret, (tuple, list)):
return [
ops.convert_to_tensor(x, dtype=dtype)
for (x, dtype) in zip(ret, self._out_dtypes)
]
elif ret is None:
return ret
else:
return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
class FuncRegistry(object): class FuncRegistry(object):
"""A helper class to keep track of registered py functions. """A helper class to keep track of registered py functions.
@ -91,16 +121,20 @@ class FuncRegistry(object):
if func is None: if func is None:
raise ValueError("callback %s is not found" % token) raise ValueError("callback %s is not found" % token)
ret = func(*args) ret = func(*args)
# Strings seem to lead to a memory leak here if they're not wrapped in a
# list. if isinstance(func, EagerFunc):
if isinstance(ret, six.binary_type): return ret
ret = [ret]
# Ensures that we return either a single numpy array or a list of numpy
# arrays.
if isinstance(ret, (tuple, list)):
return [self._convert(x) for x in ret]
else: else:
return self._convert(ret) # Strings seem to lead to a memory leak here if they're not wrapped in a
# list.
if isinstance(ret, six.binary_type):
ret = [ret]
# Ensures that we return either a single numpy array or a list of numpy
# arrays.
if isinstance(ret, (tuple, list)):
return [self._convert(x) for x in ret]
else:
return self._convert(ret)
def size(self): def size(self):
"""Returns how many functions are currently registered.""" """Returns how many functions are currently registered."""
@ -129,6 +163,86 @@ class CleanupFunc(object):
_py_funcs.remove(self._token) _py_funcs.remove(self._token)
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
"""See documentation for py_func and eager_py_func."""
is_list_or_tuple = False
if isinstance(Tout, (list, tuple)):
is_list_or_tuple = True
else:
Tout = [Tout]
if eager:
func = EagerFunc(func, Tout)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,
# i.e., when the current graph is destroyed, we remove its py funcs.
graph = ops.get_default_graph()
# pylint: disable=protected-access
while isinstance(graph, function._FuncGraph):
# If the py_func was declared inside a _FuncGraph, its lifetime should be
# bound to that of the outer graph instead.
graph = graph._outer_graph
cleanup = CleanupFunc(token)
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
graph._cleanup_py_funcs_used_in_graph = []
# When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
# will be destroyed and their __del__ will remove the 'token' from
# the funcs registry.
graph._cleanup_py_funcs_used_in_graph.append(cleanup)
# pylint: enable=protected-access
# pylint: disable=protected-access
if eager:
result = gen_script_ops._eager_py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
if stateful:
result = gen_script_ops._py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
result = gen_script_ops._py_func_stateless(
input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
return result if is_list_or_tuple else result[0]
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op.
When the returned op is executed, `func` is invoked with eager execution
enabled. Inputs are Tensor objects and func must return None or objects
that may be converted to Tensor objects.
This function has the same limitations as `py_func` with respect to
serialization and distribution.
Args:
func: A Python function which accepts a list of `Tensor` objects
having element types that match the corresponding `tf.Tensor` objects
in `inp` and returns a list of `Tensor` objects (or a single
`Tensor`, or `None`) having element types that match the
corresponding values in `Tout`.
inp: A list of `Tensor` objects.
Tout: A list or tuple of tensorflow data types or a single tensorflow data
type if there is only one, indicating what `func` returns; an empty list
if no value is returned (i.e., if the return value is `None`).
name: A name for the operation (optional).
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
def py_func(func, inp, Tout, stateful=True, name=None): def py_func(func, inp, Tout, stateful=True, name=None):
"""Wraps a python function and uses it as a TensorFlow op. """Wraps a python function and uses it as a TensorFlow op.
@ -182,46 +296,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns: Returns:
A list of `Tensor` or a single `Tensor` which `func` computes. A list of `Tensor` or a single `Tensor` which `func` computes.
""" """
token = _py_funcs.insert(func) return _internal_py_func(
# We tie the registered function's life-time with the current func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
# default graph. I.e., when the current graph is destroyed, we
# should remove its py funcs.
g = ops.get_default_graph()
# pylint: disable=protected-access
while isinstance(g, function._FuncGraph):
# If the py_func was declared inside a _FuncGraph, its lifetime should be
# bound to that of the outer graph instead.
g = g._outer_graph
cleanup = CleanupFunc(token)
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
g._cleanup_py_funcs_used_in_graph = []
# When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
# will be destroyed and their __del__ will remove the 'token' from
# the funcs registry.
g._cleanup_py_funcs_used_in_graph.append(cleanup)
# pylint: enable=protected-access
if isinstance(Tout, (list, tuple)):
is_list_or_tuple = True
else:
Tout = [Tout]
is_list_or_tuple = False
# pylint: disable=protected-access
if stateful:
result = gen_script_ops._py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
result = gen_script_ops._py_func_stateless(
input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
return result if is_list_or_tuple else result[0]
# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
# eager attribute is not set, and otherwise, it should return the gradient.
ops.NotDifferentiable("PyFunc") ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless") ops.NotDifferentiable("PyFuncStateless")