1. Add helper for casting EagerTensor* to AbstractTensorHandle*. This allows us to directly work with EagerTensors and will also allow using tf.data with unified APIs.

2. Enable unified_api_test to run with TFRT.

PiperOrigin-RevId: 331662847
Change-Id: I78e53f4eb32d40b6fc6672af3caf2a5c3c74a22c
This commit is contained in:
Saurabh Saxena 2020-09-14 17:15:52 -07:00 committed by TensorFlower Gardener
parent f08c6b6bc1
commit c11debf86c
5 changed files with 61 additions and 50 deletions

View File

@ -111,6 +111,7 @@ filegroup(
"mnist_gradients_testutil.h", "mnist_gradients_testutil.h",
"tape.h", "tape.h",
"tfe_cancellation_manager_internal.h", "tfe_cancellation_manager_internal.h",
"tfe_context_internal.h",
"tfe_executor_internal.h", "tfe_executor_internal.h",
"tfe_monitoring_internal.h", "tfe_monitoring_internal.h",
"tfe_op_attrs_internal.h", "tfe_op_attrs_internal.h",

View File

@ -24,6 +24,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor); tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor); tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor); tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
namespace tensorflow { namespace tensorflow {

View File

@ -120,6 +120,7 @@ cuda_py_test(
"no_pip", "no_pip",
"no_windows", # b/168218876 "no_windows", # b/168218876
], ],
tfrt_enabled = True,
deps = [ deps = [
":_unified_api", ":_unified_api",
":context_stack", ":context_stack",

View File

@ -22,15 +22,18 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_function.h" #include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/eager/pywrap_tensor.h" #include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/lib/core/pybind11_lib.h" #include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/lib/core/pybind11_status.h"
@ -97,23 +100,13 @@ PYBIND11_MODULE(_unified_api, m) {
} }
return dyn_cast<TracingContext>(ctx); return dyn_cast<TracingContext>(ctx);
}); });
m.def("NewImmediateExecutionContext", [](bool use_tfrt) { m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); TFE_Context* ctx =
TFE_ContextOptions options; static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
options.use_tfrt = use_tfrt;
auto* ctx = unwrap(TF_NewEagerExecutionContext(&options, status.get()));
MaybeRaiseRegisteredFromTFStatus(status.get());
if (!ctx) { if (!ctx) {
MaybeRaiseRegisteredFromStatus(Internal("Creating eager ctx failed")); MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
} }
if (!isa<ImmediateExecutionContext>(ctx)) { return unwrap(ctx);
// TODO(srbs): Add a helper to convert the kind enum to a user-friendly
// string.
MaybeRaiseRegisteredFromStatus(
Internal("TF_NewEagerExecutionContext must return an ",
"ImmediateExecutionContext, found ", ctx->getKind()));
}
return dyn_cast<ImmediateExecutionContext>(ctx);
}); });
// Unified execution context. // Unified execution context.
@ -175,14 +168,14 @@ PYBIND11_MODULE(_unified_api, m) {
return f; return f;
}); });
py::class_<ImmediateExecutionContext, AbstractContext, ImmediateContextPtr>( // Note: This does not take ownership of the C++ context, the lifetime of
m, "ImmediateExecutionContext") // which is managed by the python `Context` and is expected to outlive this
.def("CreateFloatScalarHandle", // object.
[](ImmediateExecutionContext* self, float value) { // TODO(srbs): Make AbstractContext refcounted so that the above comment is
auto* tensor = self->CreateFloatScalar(value); // not needed.
auto* handle = self->CreateLocalHandle(tensor); py::class_<ImmediateExecutionContext, AbstractContext,
return reinterpret_cast<AbstractTensorHandle*>(handle); std::unique_ptr<ImmediateExecutionContext, py::nodelete>>
}); ImmediateExecutionContext(m, "ImmediateExecutionContext");
// Unified execution operation. // Unified execution operation.
py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation") py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
@ -246,5 +239,18 @@ PYBIND11_MODULE(_unified_api, m) {
MaybeRaiseRegisteredFromStatus(s.status); MaybeRaiseRegisteredFromStatus(s.status);
return Pyo(result); return Pyo(result);
}); });
m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
if (!EagerTensor_CheckExact(handle.ptr())) {
MaybeRaiseRegisteredFromStatus(
InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
"with non-EagerTensor."));
}
TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
t->Ref();
return t;
});
py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction"); py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction");
} }

View File

@ -20,6 +20,9 @@ from __future__ import print_function
from absl.testing import parameterized from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework.experimental import _unified_api from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import context_stack as context_lib from tensorflow.python.framework.experimental import context_stack as context_lib
from tensorflow.python.framework.experimental import def_function from tensorflow.python.framework.experimental import def_function
@ -27,45 +30,44 @@ from tensorflow.python.framework.experimental import math_ops
from tensorflow.python.framework.experimental import tape as tape_lib from tensorflow.python.framework.experimental import tape as tape_lib
from tensorflow.python.platform import test from tensorflow.python.platform import test
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
SetTracingImplementation = _unified_api.SetTracingImplementation SetTracingImplementation = _unified_api.SetTracingImplementation
TensorCastHelper = _unified_api.EagerTensorToImmediateExecutionTensorHandle
def get_immediate_execution_context():
context.context().ensure_initialized()
return _unified_api.EagerContextToImmediateExecutionContext(
context.context()._handle)
class UnifiedApiTest(test.TestCase, parameterized.TestCase): class UnifiedApiTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters([ @parameterized.named_parameters([
("EagerGraph", False, False), ("Graph", False),
("EagerMlir", False, True), ("Mlir", True),
# TODO(srbs): Enable for TFRT. Segfaults right now.
# ("TfrtGraph", True, False),
# ("TfrtMlir", True, True),
]) ])
def testAdd(self, use_tfrt, use_mlir): def testAdd(self, use_mlir):
if use_mlir: if use_mlir:
SetTracingImplementation("mlir") SetTracingImplementation("mlir")
def model(a, b): def model(a, b):
return math_ops.add(a, b) return math_ops.add(a, b)
eager_ctx = NewImmediateExecutionContext(use_tfrt) with context_lib.set_default(get_immediate_execution_context()):
with context_lib.set_default(eager_ctx): a = TensorCastHelper(constant_op.constant([1., 2.]))
a = eager_ctx.CreateFloatScalarHandle(1.) b = TensorCastHelper(constant_op.constant([3., 4.]))
b = eager_ctx.CreateFloatScalarHandle(2.)
func_output = def_function.function(model)(a, b) func_output = def_function.function(model)(a, b)
self.assertAllEqual(func_output.numpy(), 3.0) self.assertAllEqual(func_output.numpy(), [4., 6.])
eager_output = model(a, b) eager_output = model(a, b)
self.assertAllEqual(eager_output.numpy(), 3.0) self.assertAllEqual(eager_output.numpy(), [4., 6.])
@parameterized.named_parameters([ @parameterized.named_parameters([
("EagerGraph", False, False), ("Graph", False),
("EagerMlir", False, True), ("Mlir", True),
# TODO(srbs): Enable for TFRT. Segfaults right now.
# ("TfrtGraph", True, False),
# ("TfrtMlir", True, True),
]) ])
def testAddGrad(self, use_tfrt, use_mlir): def testAddGrad(self, use_mlir):
if use_mlir: if use_mlir:
SetTracingImplementation("mlir") SetTracingImplementation("mlir")
@ -77,19 +79,19 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
grads = tape.gradient(result, [a, b]) grads = tape.gradient(result, [a, b])
return grads return grads
eager_ctx = NewImmediateExecutionContext(use_tfrt) with context_lib.set_default(get_immediate_execution_context()):
with context_lib.set_default(eager_ctx): a = TensorCastHelper(constant_op.constant([1., 2.]))
a = eager_ctx.CreateFloatScalarHandle(1.) b = TensorCastHelper(constant_op.constant([3., 4.]))
b = eager_ctx.CreateFloatScalarHandle(2.)
func_outputs = def_function.function(model)(a, b) func_outputs = def_function.function(model)(a, b)
self.assertAllEqual(func_outputs[0].numpy(), 1.0) self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(func_outputs[1].numpy(), 1.0) self.assertAllEqual(func_outputs[1].numpy(), [1.0, 1.0])
eager_outputs = model(a, b) eager_outputs = model(a, b)
self.assertAllEqual(eager_outputs[0].numpy(), 1.0) self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
self.assertAllEqual(eager_outputs[1].numpy(), 1.0) self.assertAllEqual(eager_outputs[1].numpy(), [1.0, 1.0])
if __name__ == "__main__": if __name__ == "__main__":
ops.enable_eager_execution()
test.main() test.main()