Add tfe.py_func, a tf.py_func-like construct that wraps a Python function and executes it eagerly.

Concretely, this change adds an EagerPyFunc op that wraps a Python function and executes it eagerly. The wrapped function should take Tensors as inputs and return Tensors as outputs; because it executes eagerly, it can make use of TensorFlow operations.

EagerPyFunc should be differentiable, in principle; a gradient will be implemented and registered in a future change. Once a gradient is implemented, tfe.py_func will probably be the easiest mechanism for experimenting with custom ops.

tfe.py_func will also make it easier to translate Python functions with side effects into defun-able code.
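For illustration, a minimal sketch of the intended usage, assuming the contrib-era `tfe` import path and the `tfe.py_func(func, inp, Tout)` signature exported in this change; `log1pexp` is an invented example function:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

def log1pexp(x):
  # Executed eagerly, so the wrapped function can freely mix TensorFlow
  # ops with ordinary Python control flow.
  return tf.log(1.0 + tf.exp(x))

x = tf.constant([1.0, 2.0])
# Wraps log1pexp in an EagerPyFunc op; inputs and outputs are Tensors.
y = tfe.py_func(log1pexp, inp=[x], Tout=tf.float32)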

PiperOrigin-RevId: 178303818
Akshay Agrawal 2017-12-07 15:18:46 -08:00 committed by TensorFlower Gardener
parent 2d4c29cd6a
commit f37380b064
12 changed files with 399 additions and 127 deletions


@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:numerics",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:backprop",


@@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@list_devices
@@num_gpus
@@py_func
@@defun
@@implicit_gradients
@@implicit_value_and_gradients
@@ -101,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.ops.variable_scope import EagerVariableStore
from tensorflow.python.ops import script_ops
from tensorflow.python.util.all_util import remove_undocumented
py_func = script_ops.eager_py_func
defun = function.defun
implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad


@@ -0,0 +1,8 @@
op {
graph_op_name: "EagerPyFunc"
summary: "Eagerly executes a python function to compute func(input)->output. The"
description: <<END
semantics of the input, output, and attributes are the same as those for
PyFunc.
END
}


@@ -0,0 +1,4 @@
op {
graph_op_name: "EagerPyFunc"
visibility: HIDDEN
}


@@ -51,4 +51,18 @@ REGISTER_OP("PyFuncStateless")
A stateless version of PyFunc.
)doc");
REGISTER_OP("EagerPyFunc")
.Input("input: Tin")
.Output("output: Tout")
.Attr("token: string")
.Attr("Tin: list(type) >= 0")
.Attr("Tout: list(type) >=0")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Eagerly executes a python function to compute func(input)->output. The
semantics of the input, output, and attributes are the same as those for
PyFunc.
)doc");
} // namespace tensorflow


@@ -280,10 +280,14 @@ cc_library(
":ndarray_tensor_bridge",
":numpy_lib",
":py_util",
":safe_ptr",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:script_ops_op_lib",
"//tensorflow/python/eager:pywrap_tfe_lib",
"//third_party/py/numpy:headers",
"//util/python:python_headers",
],


@@ -1645,6 +1645,8 @@ cuda_py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:script_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
tags = ["no_windows"],
)


@@ -23,82 +23,93 @@ from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
class PyOpTest(test.TestCase):
def np_func(x, y):
return np.sinh(x) + np.cosh(y)
def testBasic(self):
def my_func(x, y):
return np.sinh(x) + np.cosh(y)
def matmul(x, y):
return math_ops.matmul(x, y)
# single type
class PyFuncTest(test.TestCase):
"""Encapsulates tests for py_func and eager_py_func."""
# ----- Tests for py_func -----
def testSingleType(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = script_ops.py_func(my_func, [x, y], dtypes.float32)
self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
# scalar
def testScalar(self):
with self.test_session():
x = constant_op.constant(1.0, dtypes.float32)
y = constant_op.constant(2.0, dtypes.float32)
z = script_ops.py_func(my_func, [x, y], [dtypes.float32])
self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
z = self.evaluate(
script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
# array
def testArray(self):
with self.test_session():
x = constant_op.constant([1.0, 2.0], dtypes.float64)
y = constant_op.constant([2.0, 3.0], dtypes.float64)
z = script_ops.py_func(my_func, [x, y], [dtypes.float64])
self.assertAllEqual(z[0].eval(),
my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
self.assertAllEqual(z[0],
np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
# a bit exotic type (complex64)
def testComplexType(self):
with self.test_session():
x = constant_op.constant(1 + 2j, dtypes.complex64)
y = constant_op.constant(3 + 4j, dtypes.complex64)
z, = script_ops.py_func(my_func, [x, y], [dtypes.complex64])
self.assertAllClose(z.eval(), my_func(1 + 2j, 3 + 4j))
z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
# a bit exotic function (rfft)
def testRFFT(self):
with self.test_session():
x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
def rfft(x):
return np.fft.rfft(x).astype(np.complex64)
y, = script_ops.py_func(rfft, [x], [dtypes.complex64])
self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
# returns a python literal.
def testPythonLiteral(self):
with self.test_session():
def literal(x):
return 1.0 if x == 0.0 else 0.0
return 1.0 if float(x) == 0.0 else 0.0
x = constant_op.constant(0.0, dtypes.float64)
y, = script_ops.py_func(literal, [x], [dtypes.float64])
self.assertAllClose(y.eval(), 1.0)
y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
self.assertAllClose(y, 1.0)
# returns a list
def testList(self):
with self.test_session():
def list_func(x):
return [x, x + 1]
x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(list_func, [x], [dtypes.float64] * 2)
self.assertAllClose(y.eval(), 0.0)
self.assertAllClose(z.eval(), 1.0)
y = self.evaluate(
script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
self.assertAllClose(y, [0.0, 1.0])
def testTuple(self):
# returns a tuple
with self.test_session():
@@ -106,17 +117,17 @@ class PyOpTest(test.TestCase):
return x, x + 1
x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2)
self.assertAllClose(y.eval(), 0.0)
self.assertAllClose(z.eval(), 1.0)
y = self.evaluate(
script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
self.assertAllClose(y, [0.0, 1.0])
# returns a tuple, Tout and inp a tuple
with self.test_session():
x = constant_op.constant(0.0, dtypes.float64)
y, z = script_ops.py_func(tuple_func, (x,), (dtypes.float64,
dtypes.float64))
self.assertAllClose(y.eval(), 0.0)
self.assertAllClose(z.eval(), 1.0)
y = self.evaluate(
script_ops.py_func(tuple_func, (x,),
(dtypes.float64, dtypes.float64)))
self.assertAllClose(y, [0.0, 1.0])
def testStrings(self):
@@ -128,10 +139,12 @@ class PyOpTest(test.TestCase):
with self.test_session():
x = constant_op.constant([b"hello", b"hi"], dtypes.string)
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
[dtypes.string])
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
dtypes.string))
z = self.evaluate(
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
self.assertAllEqual(z, [b"hello there", b"hi there"])
def testStringsAreConvertedToBytes(self):
@@ -143,10 +156,12 @@ class PyOpTest(test.TestCase):
with self.test_session():
x = constant_op.constant(["hello", "hi"], dtypes.string)
y, = script_ops.py_func(read_fixed_length_numpy_strings, [],
[dtypes.string])
z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])
y = self.evaluate(
script_ops.py_func(read_fixed_length_numpy_strings, [],
dtypes.string))
z = self.evaluate(
script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
self.assertAllEqual(z, [b"hello there", b"hi there"])
def testObjectArraysAreConvertedToBytes(self):
@@ -186,16 +201,8 @@ class PyOpTest(test.TestCase):
def testNoInput(self):
with self.test_session():
x, = script_ops.py_func(lambda: 42.0, [], [dtypes.float64])
self.assertAllClose(x.eval(), 42.0)
def testCleanup(self):
for _ in xrange(1000):
g = ops.Graph()
with g.as_default():
c = constant_op.constant([1.], dtypes.float32)
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
self.assertTrue(script_ops._py_funcs.size() < 100)
x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
self.assertAllClose(x, 42.0)
def testAlias(self):
with self.test_session():
@@ -242,8 +249,8 @@ class PyOpTest(test.TestCase):
# Create a numpy array aliasing a tensor and a tensor aliasing this array
z, = script_ops.py_func(ident, [p], [dtypes.float32])
z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0]
# above instead of using its memory as the return value of
# session.run
self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
def testStateful(self):
@@ -319,10 +326,10 @@ class PyOpTest(test.TestCase):
def value(self):
return self._value
with self.test_session() as sess:
with self.test_session():
s = State()
op = s.increment(constant_op.constant(2, dtypes.int64))
ret = sess.run(op)
ret = self.evaluate(op)
self.assertIsNone(ret)
self.assertAllEqual([3], s.value)
@@ -336,15 +343,24 @@ class PyOpTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(sess.run(f), [])
def _testExceptionHandling(self, py_exp, tf_exp):
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
def raise_exception():
raise py_exp("blah") # pylint: disable=not-callable
f = script_ops.py_func(raise_exception, [], [])
with self.test_session() as sess:
if eager:
if context.in_eager_mode():
with self.assertRaisesRegexp(tf_exp, "blah"):
f = script_ops.eager_py_func(raise_exception, [], [])
return
else:
f = script_ops.eager_py_func(raise_exception, [], [])
else:
f = script_ops.py_func(raise_exception, [], [])
with self.test_session():
with self.assertRaisesRegexp(tf_exp, "blah"):
sess.run(f)
self.evaluate(f)
def testExceptionHandling(self):
self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
@@ -358,6 +374,89 @@ class PyOpTest(test.TestCase):
self._testExceptionHandling(WeirdError, errors.UnknownError)
# ----- Tests shared by py_func and eager_py_func -----
def testCleanup(self):
for _ in xrange(1000):
g = ops.Graph()
with g.as_default():
c = constant_op.constant([1.], dtypes.float32)
_ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
_ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
self.assertTrue(script_ops._py_funcs.size() < 100)
# ----- Tests for eager_py_func -----
@test_util.run_in_graph_and_eager_modes()
def testEagerSingleOutputInt32(self):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
with self.test_session():
ret = self.evaluate(output)
self.assertAllEqual(ret, [[3], [3], [3]])
@test_util.run_in_graph_and_eager_modes()
def testEagerSingleOutputFloat32(self):
a = array_ops.ones((3, 3), dtype=dtypes.float32)
x = array_ops.ones((3, 1), dtype=dtypes.float32)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
with self.test_session():
ret = self.evaluate(output)
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
@test_util.run_in_graph_and_eager_modes()
def testEagerArrayOutput(self):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
output = script_ops.eager_py_func(
lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.int32])
with self.test_session():
ret = self.evaluate(output)
self.assertAllEqual(ret, [[[3], [3], [3]]])
@test_util.run_in_graph_and_eager_modes()
def testEagerReturnNone(self):
def no_return_value():
return
output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
ret = self.evaluate(output)
if context.in_eager_mode():
self.assertEquals(len(ret), 0)
else:
self.assertIsNone(ret)
@test_util.run_in_graph_and_eager_modes()
def testEagerPyFuncInDefun(self):
def wrapper():
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
wrapped = function.defun(wrapper)
ret = self.evaluate(wrapped())
self.assertAllEqual(ret, [[3], [3], [3]])
@test_util.run_in_graph_and_eager_modes()
def testEagerExceptionHandling(self):
self._testExceptionHandling(
ValueError, errors.InvalidArgumentError, eager=True)
self._testExceptionHandling(
TypeError, errors.InvalidArgumentError, eager=True)
self._testExceptionHandling(
StopIteration, errors.OutOfRangeError, eager=True)
self._testExceptionHandling(
MemoryError, errors.ResourceExhaustedError, eager=True)
self._testExceptionHandling(
NotImplementedError, errors.UnimplementedError, eager=True)
class WeirdError(Exception):
pass
self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
if __name__ == "__main__":
test.main()


@@ -18,6 +18,8 @@ limitations under the License.
#include <array>
#include "numpy/arrayobject.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -25,8 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include <Python.h>
namespace tensorflow {
@@ -48,6 +52,9 @@ struct PyCall {
// with this "token".
string token;
// True if the call is associated with an EagerPyFunc.
bool eager;
// Inputs and outputs of this function invocation.
std::vector<Tensor> ins;
std::vector<Tensor> out;
@@ -55,19 +62,26 @@ struct PyCall {
// Given the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
Status MakeArgTuple(PyCall* call, PyObject** tuple) {
Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
int64 n = call->ins.size();
PyObject* lst = PyList_New(n);
CHECK(lst);
for (int64 i = 0; i < n; ++i) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
PyObject* a = nullptr;
Status s = ConvertTensorToNdarray(t, &a);
if (!s.ok()) {
Py_DECREF(lst);
return s;
if (call->eager) {
arg = EagerTensorFromHandle(TFE_NewTensorHandle(t));
if (arg == nullptr) {
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
} else {
Status s = ConvertTensorToNdarray(t, &arg);
if (!s.ok()) {
Py_DECREF(lst);
return s;
}
}
PyList_SetItem(lst, i, a);
PyList_SetItem(lst, i, arg);
}
*tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
CHECK(*tuple);
@@ -133,6 +147,18 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None;
}
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
Tensor* output_tensor,
TF_Status* tf_status) {
// TODO(akshayka): Lift the restriction requiring output tensors to
// lie in host memory; EagerPyFunc should be able to dispatch ops on GPU
// tensors, so we should eventually implement a GPU kernel for EagerPyFunc.
*output_tensor = *TFE_TensorHandleUnderlyingTensorInHostMemory(
EagerTensor_Handle(eager_tensor), tf_status);
return StatusFromTF_Status(tf_status);
}
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true;
@@ -172,21 +198,37 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
}
}
// Process the return values and converts them to tf Tensors.
// Process the return values and convert them to TF Tensors.
Status s;
if (PyList_Check(result)) {
// 'result' is a list.
call->out.clear();
for (int i = 0; i < PyList_Size(result); ++i) {
Tensor t;
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
if (call->eager) {
auto tf_status = tensorflow::make_safe(TF_NewStatus());
s = ExtractTensorFromEagerTensor(PyList_GetItem(result, i), &t,
tf_status.get());
} else {
s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
}
if (!s.ok()) {
break;
}
call->out.push_back(t);
}
} else if (EagerTensor_CheckExact(result) || result == Py_None) {
DCHECK(call->eager);
Tensor t;
if (result != Py_None) {
auto tf_status = tensorflow::make_safe(TF_NewStatus());
s = ExtractTensorFromEagerTensor(result, &t, tf_status.get());
if (s.ok()) {
call->out.push_back(t);
}
}
} else if (PyArray_Check(result)) {
// 'result' is a single ndarray.
DCHECK(!call->eager);
if (!IsSingleNone(result)) {
Tensor t;
s = ConvertNdarrayToTensor(result, &t);
@@ -375,11 +417,13 @@ class PyFuncOp : public OpKernel {
public:
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc";
}
void Compute(OpKernelContext* ctx) override {
PyCall call;
call.token = token_;
call.eager = eager_;
for (int i = 0; i < ctx->num_inputs(); ++i) {
call.ins.push_back(ctx->input(i));
}
@@ -418,9 +462,15 @@ class PyFuncOp : public OpKernel {
private:
string token_;
// True if and only if this op should execute the python function eagerly,
// i.e., if and only if the eager attribute is set.
bool eager_;
TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
};
REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("PyFuncStateless").Device(DEVICE_CPU), PyFuncOp);
REGISTER_KERNEL_BUILDER(Name("EagerPyFunc").Device(DEVICE_CPU), PyFuncOp);
} // end namespace tensorflow


@@ -24,21 +24,27 @@ limitations under the License.
namespace tensorflow {
// Called by py code on initialization.
// Called by python code on initialization.
//
// "trampoline" must represent a python function which has the
// following signature:
// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
// (string, list(ndarray)) | (string, list(EagerTensor)) ->
// ndarray | list(ndarray) | python scalar |
// EagerTensor | list(EagerTensor) | None
//
// The trampoline takes two arguments, the first is a string token
// used by the python frontend's dispatching logic; the second is a
// list of numpy ndarrays.
// list of numpy ndarrays or EagerTensor objects. It can return a
// single numpy ndarray, a list of numpy ndarrays, a python scalar, an
// EagerTensor, a list of EagerTensors, or None.
//
// The trampoline can return a single numpy ndarray, a list of numpy
ndarrays, or a simple Python scalar. The C++ runtime converts them,
// if supported, back to Tensor objects.
// PyFunc requires inputs and outputs to be ndarrays. EagerPyFunc requires
// inputs to be a list of EagerTensors and outputs to be an EagerTensor, a list
// of EagerTensors, or None.
//
// This is called by script_ops.py during its module initialization.
// The C++ runtime converts outputs back to Tensor objects.
//
// This function is called by script_ops.py during its module initialization.
//
// TODO(zhifengc): Support distributed runtime.
void InitializePyTrampoline(PyObject* trampoline);
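For illustration, a hedged sketch of the trampoline contract documented above; the `Trampoline` class and its `insert` method are hypothetical stand-ins (in this change, the `FuncRegistry` instance defined in script_ops.py plays this role and is handed to InitializePyTrampoline at module initialization):

# Hypothetical sketch of the trampoline contract: a Python callable
# invoked as trampoline(token, args).
class Trampoline(object):

  def __init__(self):
    self._funcs = {}  # token -> registered Python function

  def insert(self, token, func):
    self._funcs[token] = func

  def __call__(self, token, args):
    # For PyFunc, `args` holds numpy ndarrays; for EagerPyFunc, it holds
    # EagerTensor objects. The return value may be an ndarray, a list of
    # ndarrays, a python scalar, an EagerTensor, a list of EagerTensors,
    # or None, per the comment above.
    func = self._funcs[token]
    return func(*args)

The real FuncRegistry additionally converts non-eager return values back to numpy arrays before handing them to the C++ runtime.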


@@ -341,6 +341,7 @@ TruncatedNormal
# script_ops
PyFunc
PyFuncStateless
EagerPyFunc
# sdca_ops


@@ -29,11 +29,41 @@ import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_script_ops
class EagerFunc(object):
"""A wrapper for a function owned by an EagerPyFunc."""
def __init__(self, func, Tout):
"""Constructs an EagerFunc.
Args:
func: The function to wrap.
Tout: A list of datatypes for the output; an empty list if the output is
None.
"""
self._func = func
self._out_dtypes = Tout
def __call__(self, *args, **kwargs):
"""Passes args, kwargs to `self._func`, which is executed eagerly."""
with context.eager_mode():
ret = self._func(*args, **kwargs)
if isinstance(ret, (tuple, list)):
return [
ops.convert_to_tensor(x, dtype=dtype)
for (x, dtype) in zip(ret, self._out_dtypes)
]
elif ret is None:
return ret
else:
return ops.convert_to_tensor(ret, dtype=self._out_dtypes[0])
class FuncRegistry(object):
"""A helper class to keep track of registered py functions.
@@ -91,16 +121,20 @@ class FuncRegistry(object):
if func is None:
raise ValueError("callback %s is not found" % token)
ret = func(*args)
# Strings seem to lead to a memory leak here if they're not wrapped in a
# list.
if isinstance(ret, six.binary_type):
ret = [ret]
# Ensures that we return either a single numpy array or a list of numpy
# arrays.
if isinstance(ret, (tuple, list)):
return [self._convert(x) for x in ret]
if isinstance(func, EagerFunc):
return ret
else:
return self._convert(ret)
# Strings seem to lead to a memory leak here if they're not wrapped in a
# list.
if isinstance(ret, six.binary_type):
ret = [ret]
# Ensures that we return either a single numpy array or a list of numpy
# arrays.
if isinstance(ret, (tuple, list)):
return [self._convert(x) for x in ret]
else:
return self._convert(ret)
def size(self):
"""Returns how many functions are currently registered."""
@@ -129,6 +163,86 @@ class CleanupFunc(object):
_py_funcs.remove(self._token)
def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
"""See documentation for py_func and eager_py_func."""
is_list_or_tuple = False
if isinstance(Tout, (list, tuple)):
is_list_or_tuple = True
else:
Tout = [Tout]
if eager:
func = EagerFunc(func, Tout)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,
# i.e., when the current graph is destroyed, we remove its py funcs.
graph = ops.get_default_graph()
# pylint: disable=protected-access
while isinstance(graph, function._FuncGraph):
# If the py_func was declared inside a _FuncGraph, its lifetime should be
# bound to that of the outer graph instead.
graph = graph._outer_graph
cleanup = CleanupFunc(token)
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
graph._cleanup_py_funcs_used_in_graph = []
# When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
# will be destroyed and their __del__ will remove the 'token' from
# the funcs registry.
graph._cleanup_py_funcs_used_in_graph.append(cleanup)
# pylint: enable=protected-access
# pylint: disable=protected-access
if eager:
result = gen_script_ops._eager_py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
if stateful:
result = gen_script_ops._py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
result = gen_script_ops._py_func_stateless(
input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
return result if is_list_or_tuple else result[0]
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op.
When the returned op is executed, `func` is invoked with eager execution
enabled. Inputs are Tensor objects and func must return None or objects
that may be converted to Tensor objects.
This function has the same limitations as `py_func` with respect to
serialization and distribution.
Args:
func: A Python function which accepts a list of `Tensor` objects
having element types that match the corresponding `tf.Tensor` objects
in `inp` and returns a list of `Tensor` objects (or a single
`Tensor`, or `None`) having element types that match the
corresponding values in `Tout`.
inp: A list of `Tensor` objects.
Tout: A list or tuple of tensorflow data types or a single tensorflow data
type if there is only one, indicating what `func` returns; an empty list
if no value is returned (i.e., if the return value is `None`).
name: A name for the operation (optional).
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
def py_func(func, inp, Tout, stateful=True, name=None):
"""Wraps a python function and uses it as a TensorFlow op.
@@ -182,46 +296,12 @@ def py_func(func, inp, Tout, stateful=True, name=None):
Returns:
A list of `Tensor` or a single `Tensor` which `func` computes.
"""
token = _py_funcs.insert(func)
# We tie the registered function's life-time with the current
# default graph. I.e., when the current graph is destroyed, we
# should remove its py funcs.
g = ops.get_default_graph()
# pylint: disable=protected-access
while isinstance(g, function._FuncGraph):
# If the py_func was declared inside a _FuncGraph, its lifetime should be
# bound to that of the outer graph instead.
g = g._outer_graph
cleanup = CleanupFunc(token)
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
g._cleanup_py_funcs_used_in_graph = []
# When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
# will be destroyed and their __del__ will remove the 'token' from
# the funcs registry.
g._cleanup_py_funcs_used_in_graph.append(cleanup)
# pylint: enable=protected-access
if isinstance(Tout, (list, tuple)):
is_list_or_tuple = True
else:
Tout = [Tout]
is_list_or_tuple = False
# pylint: disable=protected-access
if stateful:
result = gen_script_ops._py_func(
input=inp, token=token, Tout=Tout, name=name)
else:
result = gen_script_ops._py_func_stateless(
input=inp, token=token, Tout=Tout, name=name)
# pylint: enable=protected-access
return result if is_list_or_tuple else result[0]
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
# TODO(akshayka): PyFuncs where the 'eager' attribute is set to True should be
# differentiable, i.e., the gradient of PyFunc should propagate Nones if the
# eager attribute is not set, and otherwise, it should return the gradient.
ops.NotDifferentiable("PyFunc")
ops.NotDifferentiable("PyFuncStateless")