1. Add helper for casting EagerTensor* to AbstractTensorHandle*. This allows us to directly work with EagerTensors and will also allow using tf.data with unified APIs.
2. Enable unified_api_test to run with TFRT. PiperOrigin-RevId: 331662847 Change-Id: I78e53f4eb32d40b6fc6672af3caf2a5c3c74a22c
This commit is contained in:
parent
f08c6b6bc1
commit
c11debf86c
tensorflow
@ -111,6 +111,7 @@ filegroup(
|
||||
"mnist_gradients_testutil.h",
|
||||
"tape.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_context_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
"tfe_op_attrs_internal.h",
|
||||
|
@ -24,6 +24,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
|
||||
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
|
||||
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
|
||||
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
|
||||
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -120,6 +120,7 @@ cuda_py_test(
|
||||
"no_pip",
|
||||
"no_windows", # b/168218876
|
||||
],
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":_unified_api",
|
||||
":context_stack",
|
||||
|
@ -22,15 +22,18 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
@ -97,23 +100,13 @@ PYBIND11_MODULE(_unified_api, m) {
|
||||
}
|
||||
return dyn_cast<TracingContext>(ctx);
|
||||
});
|
||||
m.def("NewImmediateExecutionContext", [](bool use_tfrt) {
|
||||
Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
|
||||
TFE_ContextOptions options;
|
||||
options.use_tfrt = use_tfrt;
|
||||
auto* ctx = unwrap(TF_NewEagerExecutionContext(&options, status.get()));
|
||||
MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
|
||||
TFE_Context* ctx =
|
||||
static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
|
||||
if (!ctx) {
|
||||
MaybeRaiseRegisteredFromStatus(Internal("Creating eager ctx failed"));
|
||||
MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
|
||||
}
|
||||
if (!isa<ImmediateExecutionContext>(ctx)) {
|
||||
// TODO(srbs): Add a helper to convert the kind enum to a user-friendly
|
||||
// string.
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
Internal("TF_NewEagerExecutionContext must return an ",
|
||||
"ImmediateExecutionContext, found ", ctx->getKind()));
|
||||
}
|
||||
return dyn_cast<ImmediateExecutionContext>(ctx);
|
||||
return unwrap(ctx);
|
||||
});
|
||||
|
||||
// Unified execution context.
|
||||
@ -175,14 +168,14 @@ PYBIND11_MODULE(_unified_api, m) {
|
||||
return f;
|
||||
});
|
||||
|
||||
py::class_<ImmediateExecutionContext, AbstractContext, ImmediateContextPtr>(
|
||||
m, "ImmediateExecutionContext")
|
||||
.def("CreateFloatScalarHandle",
|
||||
[](ImmediateExecutionContext* self, float value) {
|
||||
auto* tensor = self->CreateFloatScalar(value);
|
||||
auto* handle = self->CreateLocalHandle(tensor);
|
||||
return reinterpret_cast<AbstractTensorHandle*>(handle);
|
||||
});
|
||||
// Note: This does not take ownership of the C++ context, the lifetime of
|
||||
// which is managed by the python `Context` and is expected to outlive this
|
||||
// object.
|
||||
// TODO(srbs): Make AbstractContext refcounted so that the above comment is
|
||||
// not needed.
|
||||
py::class_<ImmediateExecutionContext, AbstractContext,
|
||||
std::unique_ptr<ImmediateExecutionContext, py::nodelete>>
|
||||
ImmediateExecutionContext(m, "ImmediateExecutionContext");
|
||||
|
||||
// Unified execution operation.
|
||||
py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
|
||||
@ -246,5 +239,18 @@ PYBIND11_MODULE(_unified_api, m) {
|
||||
MaybeRaiseRegisteredFromStatus(s.status);
|
||||
return Pyo(result);
|
||||
});
|
||||
|
||||
m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
|
||||
if (!EagerTensor_CheckExact(handle.ptr())) {
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
|
||||
"with non-EagerTensor."));
|
||||
}
|
||||
TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
|
||||
auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
|
||||
t->Ref();
|
||||
return t;
|
||||
});
|
||||
|
||||
py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction");
|
||||
}
|
||||
|
@ -20,6 +20,9 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.experimental import _unified_api
|
||||
from tensorflow.python.framework.experimental import context_stack as context_lib
|
||||
from tensorflow.python.framework.experimental import def_function
|
||||
@ -27,45 +30,44 @@ from tensorflow.python.framework.experimental import math_ops
|
||||
from tensorflow.python.framework.experimental import tape as tape_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
|
||||
SetTracingImplementation = _unified_api.SetTracingImplementation
|
||||
TensorCastHelper = _unified_api.EagerTensorToImmediateExecutionTensorHandle
|
||||
|
||||
|
||||
def get_immediate_execution_context():
|
||||
context.context().ensure_initialized()
|
||||
return _unified_api.EagerContextToImmediateExecutionContext(
|
||||
context.context()._handle)
|
||||
|
||||
|
||||
class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("EagerGraph", False, False),
|
||||
("EagerMlir", False, True),
|
||||
# TODO(srbs): Enable for TFRT. Segfaults right now.
|
||||
# ("TfrtGraph", True, False),
|
||||
# ("TfrtMlir", True, True),
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testAdd(self, use_tfrt, use_mlir):
|
||||
def testAdd(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a, b):
|
||||
return math_ops.add(a, b)
|
||||
|
||||
eager_ctx = NewImmediateExecutionContext(use_tfrt)
|
||||
with context_lib.set_default(eager_ctx):
|
||||
a = eager_ctx.CreateFloatScalarHandle(1.)
|
||||
b = eager_ctx.CreateFloatScalarHandle(2.)
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||
|
||||
func_output = def_function.function(model)(a, b)
|
||||
self.assertAllEqual(func_output.numpy(), 3.0)
|
||||
self.assertAllEqual(func_output.numpy(), [4., 6.])
|
||||
|
||||
eager_output = model(a, b)
|
||||
self.assertAllEqual(eager_output.numpy(), 3.0)
|
||||
self.assertAllEqual(eager_output.numpy(), [4., 6.])
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("EagerGraph", False, False),
|
||||
("EagerMlir", False, True),
|
||||
# TODO(srbs): Enable for TFRT. Segfaults right now.
|
||||
# ("TfrtGraph", True, False),
|
||||
# ("TfrtMlir", True, True),
|
||||
("Graph", False),
|
||||
("Mlir", True),
|
||||
])
|
||||
def testAddGrad(self, use_tfrt, use_mlir):
|
||||
def testAddGrad(self, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
@ -77,19 +79,19 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||
grads = tape.gradient(result, [a, b])
|
||||
return grads
|
||||
|
||||
eager_ctx = NewImmediateExecutionContext(use_tfrt)
|
||||
with context_lib.set_default(eager_ctx):
|
||||
a = eager_ctx.CreateFloatScalarHandle(1.)
|
||||
b = eager_ctx.CreateFloatScalarHandle(2.)
|
||||
with context_lib.set_default(get_immediate_execution_context()):
|
||||
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||
|
||||
func_outputs = def_function.function(model)(a, b)
|
||||
self.assertAllEqual(func_outputs[0].numpy(), 1.0)
|
||||
self.assertAllEqual(func_outputs[1].numpy(), 1.0)
|
||||
self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
|
||||
self.assertAllEqual(func_outputs[1].numpy(), [1.0, 1.0])
|
||||
|
||||
eager_outputs = model(a, b)
|
||||
self.assertAllEqual(eager_outputs[0].numpy(), 1.0)
|
||||
self.assertAllEqual(eager_outputs[1].numpy(), 1.0)
|
||||
self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
|
||||
self.assertAllEqual(eager_outputs[1].numpy(), [1.0, 1.0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user