1. Add helper for casting EagerTensor* to AbstractTensorHandle*. This allows us to directly work with EagerTensors and will also allow using tf.data with unified APIs.
2. Enable unified_api_test to run with TFRT. PiperOrigin-RevId: 331662847 Change-Id: I78e53f4eb32d40b6fc6672af3caf2a5c3c74a22c
This commit is contained in:
parent
f08c6b6bc1
commit
c11debf86c
@ -111,6 +111,7 @@ filegroup(
|
|||||||
"mnist_gradients_testutil.h",
|
"mnist_gradients_testutil.h",
|
||||||
"tape.h",
|
"tape.h",
|
||||||
"tfe_cancellation_manager_internal.h",
|
"tfe_cancellation_manager_internal.h",
|
||||||
|
"tfe_context_internal.h",
|
||||||
"tfe_executor_internal.h",
|
"tfe_executor_internal.h",
|
||||||
"tfe_monitoring_internal.h",
|
"tfe_monitoring_internal.h",
|
||||||
"tfe_op_attrs_internal.h",
|
"tfe_op_attrs_internal.h",
|
||||||
|
|||||||
@ -24,6 +24,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
|
|||||||
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
|
tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
|
||||||
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
|
tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
|
||||||
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
|
tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor);
|
||||||
|
TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
|||||||
@ -120,6 +120,7 @@ cuda_py_test(
|
|||||||
"no_pip",
|
"no_pip",
|
||||||
"no_windows", # b/168218876
|
"no_windows", # b/168218876
|
||||||
],
|
],
|
||||||
|
tfrt_enabled = True,
|
||||||
deps = [
|
deps = [
|
||||||
":_unified_api",
|
":_unified_api",
|
||||||
":context_stack",
|
":context_stack",
|
||||||
|
|||||||
@ -22,15 +22,18 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/abstract_function.h"
|
#include "tensorflow/c/eager/abstract_function.h"
|
||||||
#include "tensorflow/c/eager/abstract_operation.h"
|
#include "tensorflow/c/eager/abstract_operation.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
|
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/python/eager/pywrap_tensor.h"
|
#include "tensorflow/python/eager/pywrap_tensor.h"
|
||||||
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
#include "tensorflow/python/lib/core/pybind11_lib.h"
|
||||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||||
@ -97,23 +100,13 @@ PYBIND11_MODULE(_unified_api, m) {
|
|||||||
}
|
}
|
||||||
return dyn_cast<TracingContext>(ctx);
|
return dyn_cast<TracingContext>(ctx);
|
||||||
});
|
});
|
||||||
m.def("NewImmediateExecutionContext", [](bool use_tfrt) {
|
m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
|
||||||
Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
|
TFE_Context* ctx =
|
||||||
TFE_ContextOptions options;
|
static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
|
||||||
options.use_tfrt = use_tfrt;
|
|
||||||
auto* ctx = unwrap(TF_NewEagerExecutionContext(&options, status.get()));
|
|
||||||
MaybeRaiseRegisteredFromTFStatus(status.get());
|
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
MaybeRaiseRegisteredFromStatus(Internal("Creating eager ctx failed"));
|
MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
|
||||||
}
|
}
|
||||||
if (!isa<ImmediateExecutionContext>(ctx)) {
|
return unwrap(ctx);
|
||||||
// TODO(srbs): Add a helper to convert the kind enum to a user-friendly
|
|
||||||
// string.
|
|
||||||
MaybeRaiseRegisteredFromStatus(
|
|
||||||
Internal("TF_NewEagerExecutionContext must return an ",
|
|
||||||
"ImmediateExecutionContext, found ", ctx->getKind()));
|
|
||||||
}
|
|
||||||
return dyn_cast<ImmediateExecutionContext>(ctx);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Unified execution context.
|
// Unified execution context.
|
||||||
@ -175,14 +168,14 @@ PYBIND11_MODULE(_unified_api, m) {
|
|||||||
return f;
|
return f;
|
||||||
});
|
});
|
||||||
|
|
||||||
py::class_<ImmediateExecutionContext, AbstractContext, ImmediateContextPtr>(
|
// Note: This does not take ownership of the C++ context, the lifetime of
|
||||||
m, "ImmediateExecutionContext")
|
// which is managed by the python `Context` and is expected to outlive this
|
||||||
.def("CreateFloatScalarHandle",
|
// object.
|
||||||
[](ImmediateExecutionContext* self, float value) {
|
// TODO(srbs): Make AbstractContext refcounted so that the above comment is
|
||||||
auto* tensor = self->CreateFloatScalar(value);
|
// not needed.
|
||||||
auto* handle = self->CreateLocalHandle(tensor);
|
py::class_<ImmediateExecutionContext, AbstractContext,
|
||||||
return reinterpret_cast<AbstractTensorHandle*>(handle);
|
std::unique_ptr<ImmediateExecutionContext, py::nodelete>>
|
||||||
});
|
ImmediateExecutionContext(m, "ImmediateExecutionContext");
|
||||||
|
|
||||||
// Unified execution operation.
|
// Unified execution operation.
|
||||||
py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
|
py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
|
||||||
@ -246,5 +239,18 @@ PYBIND11_MODULE(_unified_api, m) {
|
|||||||
MaybeRaiseRegisteredFromStatus(s.status);
|
MaybeRaiseRegisteredFromStatus(s.status);
|
||||||
return Pyo(result);
|
return Pyo(result);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
|
||||||
|
if (!EagerTensor_CheckExact(handle.ptr())) {
|
||||||
|
MaybeRaiseRegisteredFromStatus(
|
||||||
|
InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
|
||||||
|
"with non-EagerTensor."));
|
||||||
|
}
|
||||||
|
TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
|
||||||
|
auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
|
||||||
|
t->Ref();
|
||||||
|
return t;
|
||||||
|
});
|
||||||
|
|
||||||
py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction");
|
py::class_<AbstractFunction> AbstractFunction(m, "AbstractFunction");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework.experimental import _unified_api
|
from tensorflow.python.framework.experimental import _unified_api
|
||||||
from tensorflow.python.framework.experimental import context_stack as context_lib
|
from tensorflow.python.framework.experimental import context_stack as context_lib
|
||||||
from tensorflow.python.framework.experimental import def_function
|
from tensorflow.python.framework.experimental import def_function
|
||||||
@ -27,45 +30,44 @@ from tensorflow.python.framework.experimental import math_ops
|
|||||||
from tensorflow.python.framework.experimental import tape as tape_lib
|
from tensorflow.python.framework.experimental import tape as tape_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
|
|
||||||
SetTracingImplementation = _unified_api.SetTracingImplementation
|
SetTracingImplementation = _unified_api.SetTracingImplementation
|
||||||
|
TensorCastHelper = _unified_api.EagerTensorToImmediateExecutionTensorHandle
|
||||||
|
|
||||||
|
|
||||||
|
def get_immediate_execution_context():
|
||||||
|
context.context().ensure_initialized()
|
||||||
|
return _unified_api.EagerContextToImmediateExecutionContext(
|
||||||
|
context.context()._handle)
|
||||||
|
|
||||||
|
|
||||||
class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@parameterized.named_parameters([
|
@parameterized.named_parameters([
|
||||||
("EagerGraph", False, False),
|
("Graph", False),
|
||||||
("EagerMlir", False, True),
|
("Mlir", True),
|
||||||
# TODO(srbs): Enable for TFRT. Segfaults right now.
|
|
||||||
# ("TfrtGraph", True, False),
|
|
||||||
# ("TfrtMlir", True, True),
|
|
||||||
])
|
])
|
||||||
def testAdd(self, use_tfrt, use_mlir):
|
def testAdd(self, use_mlir):
|
||||||
if use_mlir:
|
if use_mlir:
|
||||||
SetTracingImplementation("mlir")
|
SetTracingImplementation("mlir")
|
||||||
|
|
||||||
def model(a, b):
|
def model(a, b):
|
||||||
return math_ops.add(a, b)
|
return math_ops.add(a, b)
|
||||||
|
|
||||||
eager_ctx = NewImmediateExecutionContext(use_tfrt)
|
with context_lib.set_default(get_immediate_execution_context()):
|
||||||
with context_lib.set_default(eager_ctx):
|
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||||
a = eager_ctx.CreateFloatScalarHandle(1.)
|
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||||
b = eager_ctx.CreateFloatScalarHandle(2.)
|
|
||||||
|
|
||||||
func_output = def_function.function(model)(a, b)
|
func_output = def_function.function(model)(a, b)
|
||||||
self.assertAllEqual(func_output.numpy(), 3.0)
|
self.assertAllEqual(func_output.numpy(), [4., 6.])
|
||||||
|
|
||||||
eager_output = model(a, b)
|
eager_output = model(a, b)
|
||||||
self.assertAllEqual(eager_output.numpy(), 3.0)
|
self.assertAllEqual(eager_output.numpy(), [4., 6.])
|
||||||
|
|
||||||
@parameterized.named_parameters([
|
@parameterized.named_parameters([
|
||||||
("EagerGraph", False, False),
|
("Graph", False),
|
||||||
("EagerMlir", False, True),
|
("Mlir", True),
|
||||||
# TODO(srbs): Enable for TFRT. Segfaults right now.
|
|
||||||
# ("TfrtGraph", True, False),
|
|
||||||
# ("TfrtMlir", True, True),
|
|
||||||
])
|
])
|
||||||
def testAddGrad(self, use_tfrt, use_mlir):
|
def testAddGrad(self, use_mlir):
|
||||||
if use_mlir:
|
if use_mlir:
|
||||||
SetTracingImplementation("mlir")
|
SetTracingImplementation("mlir")
|
||||||
|
|
||||||
@ -77,19 +79,19 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
|||||||
grads = tape.gradient(result, [a, b])
|
grads = tape.gradient(result, [a, b])
|
||||||
return grads
|
return grads
|
||||||
|
|
||||||
eager_ctx = NewImmediateExecutionContext(use_tfrt)
|
with context_lib.set_default(get_immediate_execution_context()):
|
||||||
with context_lib.set_default(eager_ctx):
|
a = TensorCastHelper(constant_op.constant([1., 2.]))
|
||||||
a = eager_ctx.CreateFloatScalarHandle(1.)
|
b = TensorCastHelper(constant_op.constant([3., 4.]))
|
||||||
b = eager_ctx.CreateFloatScalarHandle(2.)
|
|
||||||
|
|
||||||
func_outputs = def_function.function(model)(a, b)
|
func_outputs = def_function.function(model)(a, b)
|
||||||
self.assertAllEqual(func_outputs[0].numpy(), 1.0)
|
self.assertAllEqual(func_outputs[0].numpy(), [1.0, 1.0])
|
||||||
self.assertAllEqual(func_outputs[1].numpy(), 1.0)
|
self.assertAllEqual(func_outputs[1].numpy(), [1.0, 1.0])
|
||||||
|
|
||||||
eager_outputs = model(a, b)
|
eager_outputs = model(a, b)
|
||||||
self.assertAllEqual(eager_outputs[0].numpy(), 1.0)
|
self.assertAllEqual(eager_outputs[0].numpy(), [1.0, 1.0])
|
||||||
self.assertAllEqual(eager_outputs[1].numpy(), 1.0)
|
self.assertAllEqual(eager_outputs[1].numpy(), [1.0, 1.0])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
ops.enable_eager_execution()
|
||||||
test.main()
|
test.main()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user