1. Add helper for casting EagerTensor* to AbstractTensorHandle*. This allows us to directly work with EagerTensors and will also allow using tf.data with unified APIs.

2. Enable unified_api_test to run with TFRT.

PiperOrigin-RevId: 331662847
Change-Id: I78e53f4eb32d40b6fc6672af3caf2a5c3c74a22c
This commit is contained in:
Saurabh Saxena 2020-09-14 17:15:52 -07:00 committed by TensorFlower Gardener
parent f08c6b6bc1
commit c11debf86c
5 changed files with 61 additions and 50 deletions
tensorflow
c/eager
python

View File

@ -111,6 +111,7 @@ filegroup(
"mnist_gradients_testutil.h",
"tape.h",
"tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h",

View File

@ -24,6 +24,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
namespace tensorflow {

View File

@ -120,6 +120,7 @@ cuda_py_test(
"no_pip",
"no_windows", # b/168218876
],
tfrt_enabled = True,
deps = [
":_unified_api",
":context_stack",

View File

@ -22,15 +22,18 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
@ -97,23 +100,13 @@ PYBIND11_MODULE(_unified_api, m) {
}
return dyn_cast<TracingContext>(ctx);
});
m.def("NewImmediateExecutionContext", [](bool use_tfrt) {
Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
TFE_ContextOptions options;
options.use_tfrt = use_tfrt;
auto* ctx = unwrap(TF_NewEagerExecutionContext(&options, status.get()));
MaybeRaiseRegisteredFromTFStatus(status.get());
m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
TFE_Context* ctx =
static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
if (!ctx) {
MaybeRaiseRegisteredFromStatus(Internal("Creating eager ctx failed"));
MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
}
if (!isa<ImmediateExecutionContext>(ctx)) {
// TODO(srbs): Add a helper to convert the kind enum to a user-friendly
// string.
MaybeRaiseRegisteredFromStatus(
Internal("TF_NewEagerExecutionContext must return an ",
"ImmediateExecutionContext, found ", ctx->getKind()));
}
return dyn_cast<ImmediateExecutionContext>(ctx);
return unwrap(ctx);
});
// Unified execution context.
@ -175,14 +168,14 @@ PYBIND11_MODULE(_unified_api, m) {
return f;
});
py::class_<ImmediateExecutionContext, AbstractContext, ImmediateContextPtr>(
m, "ImmediateExecutionContext")
.def("CreateFloatScalarHandle",
[](ImmediateExecutionContext* self, float value) {
auto* tensor = self->CreateFloatScalar(value);
auto* handle = self->CreateLocalHandle(tensor);
return reinterpret_cast<AbstractTensorHandle*>(handle);
});
// Note: This does not take ownership of the C++ context, the lifetime of
// which is managed by the python `Context` and is expected to outlive this
// object.
// TODO(srbs): Make AbstractContext refcounted so that the above comment is
// not needed.
py::class_<ImmediateExecutionContext, AbstractContext,
std::unique_ptr<ImmediateExecutionContext, py::nodelete>>
ImmediateExecutionContext(m, "ImmediateExecutionContext");
// Unified execution operation.
py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
@ -246,5 +239,18 @@ PYBIND11_MODULE(_unified_api, m) {
MaybeRaiseRegisteredFromStatus(s.status);
return Pyo(result);
});
m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
if (!EagerTensor_CheckExact(handle.ptr())) {
MaybeRaiseRegisteredFromStatus(
InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
"with non-EagerTensor."));
}
TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
t->Ref();
return t;
});
py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction");
}

View File

@ -20,6 +20,9 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import context_stack as context_lib
from tensorflow.python.framework.experimental import def_function
@ -27,45 +30,44 @@ from tensorflow.python.framework.experimental import math_ops
from tensorflow.python.framework.experimental import tape as tape_lib
from tensorflow.python.platform import test
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
SetTracingImplementation = _unified_api.SetTracingImplementation
TensorCastHelper = _unified_api.EagerTensorToImmediateExecutionTensorHandle
def get_immediate_execution_context():
context.context().ensure_initialized()
return _unified_api.EagerContextToImmediateExecutionContext(
context.context()._handle)
class UnifiedApiTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters([
("EagerGraph", False, False),
("EagerMlir", False, True),
# TODO(srbs): Enable for TFRT. Segfaults right now.
# ("TfrtGraph", True, False),
# ("TfrtMlir", True, True),
("Graph", False),
("Mlir", True),
])
def testAdd(self, use_tfrt, use_mlir):
def testAdd(self, use_mlir):
if use_mlir:
SetTracingImplementation("mlir")
def model(a, b):
return math_ops.add(a, b)
eager_ctx = NewImmediateExecutionContext(use_tfrt)
with context_lib.set_default(eager_ctx):
a = eager_ctx.CreateFloatScalarHandle(1.)
b = eager_ctx.CreateFloatScalarHandle(2.)
with context_lib.set_default(get_immediate_execution_context()):
a = TensorCastHelper(constant_op.constant([1., 2.]))
b = TensorCastHelper(constant_op.constant([3., 4.]))
func_output = def_function.function(model)(a, b)
self.assertAllEqual(func_output.numpy(), 3.0)
self.assertAllEqual(func_output.numpy(), [4., 6.])
eager_output = model(a, b)
self.assertAllEqual(eager_output.numpy(), 3.0)
self.assertAllEqual(eager_output.numpy(), [4., 6.])
@parameterized.named_parameters([
("EagerGraph", False, False),
("EagerMlir", False, True),
# TODO(srbs): Enable for TFRT. Segfaults right now.
# ("TfrtGraph", True, False),
# ("TfrtMlir", True, True),
("Graph", False),
("Mlir", True),
])
def testAddGrad(self, use_tfrt, use_mlir):
def testAddGrad(self, use_mlir):
if use_mlir:
SetTracingImplementation("mlir")
@ -77,19 +79,19 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(result, [a, b])
return grads
eager_ctx = NewImmediateExecutionContext(use_tfrt)
with context_lib.set_default(eager_ctx):
a = eager_ctx.CreateFloatScalarHandle(1.)
b = eager_ctx.CreateFloatScalarHandle(2.)
with context_lib.set_default(get_immediate_execution_context()):
a = TensorCastHelper(constant_op.constant([1., 2.]))
b = TensorCastHelper(constant_op.constant([3., 4.]))
func_outputs = def_function.function(model)(a, b)
self.assertAllEqual(func_outputs[0].numpy(), 1.0)
self.assertAllEqual(func_outputs[1].numpy(), 1.0)
self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(func_outputs[1].numpy(), [1.0, 1.0])
eager_outputs = model(a, b)
self.assertAllEqual(eager_outputs[0].numpy(), 1.0)
self.assertAllEqual(eager_outputs[1].numpy(), 1.0)
self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(eager_outputs[1].numpy(), [1.0, 1.0])
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()