Update TFE_Py_TensorShapeSlice to use abstract interface APIs instead of TF_* APIs.

This is so that the method can be work with TFRT.

PiperOrigin-RevId: 342877130
Change-Id: I1d734daeec17b67c5ae602163d6c1ceb67088fa6
This commit is contained in:
Chuanhao Zhuge 2020-11-17 09:19:02 -08:00 committed by TensorFlower Gardener
parent 02df2f9599
commit 7afd763939
3 changed files with 68 additions and 38 deletions

View File

@ -153,8 +153,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status* status) override;
ImmediateExecutionOperation* CreateOperation() override;
// Convert a TFRT TensorHandle to tensorflow::TensorHandle. In this case,
// just forward the input TensorHandle.
// This is a virtual helper function to convert TFRT TensorHandle to
// tensorflow::TensorHandle. In current runtime EagerContext, just forward
// the input since the input tensor handle is already a
// tensorflow::TensorHandle.
ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
ImmediateExecutionTensorHandle* handle) override;

View File

@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
@ -1030,49 +1032,71 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
return nullptr;
}
PyObject* py_context = GetPyEagerContext();
if (py_context == nullptr) {
PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
"Cannot create EagerTensor when "
"EagerContext is not valid")
.c_str());
return nullptr;
}
TFE_Context* ctx = GetContextHandle(py_context);
Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
auto tensor = tensorflow::make_safe(TF_AllocateTensor(
TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
auto status = tensorflow::make_safe(TF_NewStatus());
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
PyObject* tensor_obj = tensors_array[i];
if (!EagerTensor_CheckExact(tensor_obj)) {
PyErr_SetString(PyExc_TypeError,
tensorflow::strings::StrCat(
"Expected a list of EagerTensors but "
"element ",
i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
.c_str());
return nullptr;
}
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
TFE_TensorHandle* handle = t->handle;
int num_dims = TFE_TensorHandleNumDims(handle, status.get());
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
return nullptr;
auto status = tensorflow::make_safe(TF_NewStatus());
// Create an empty tensor.
auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
if (num_tensors_int > 0) {
int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
// Fill the tensor with dims.
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
PyObject* tensor_obj = tensors_array[i];
if (!EagerTensor_CheckExact(tensor_obj)) {
PyErr_SetString(
PyExc_TypeError,
tensorflow::strings::StrCat("Expected a list of EagerTensors but "
"element ",
i, " has type \"",
Py_TYPE(tensor_obj)->tp_name, "\"")
.c_str());
return nullptr;
}
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
TFE_TensorHandle* handle = t->handle;
int num_dims = TFE_TensorHandleNumDims(handle, status.get());
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
return nullptr;
}
if (slice_dim >= num_dims) {
PyErr_SetString(
PyExc_IndexError,
tensorflow::strings::StrCat("Slice dimension (", slice_dim,
") must be smaller than rank of all "
"tensors, but tensor at index ",
i, " has rank ", num_dims)
.c_str());
return nullptr;
}
int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
return nullptr;
}
data[i] = dim;
}
if (slice_dim >= num_dims) {
PyErr_SetString(
PyExc_IndexError,
tensorflow::strings::StrCat("Slice dimension (", slice_dim,
") must be smaller than rank of all "
"tensors, but tensor at index ",
i, " has rank ", num_dims)
.c_str());
return nullptr;
}
int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
return nullptr;
}
data[i] = dim;
}
TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
TFE_TensorHandle* handle =
tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
if (!status->status.ok()) {
PyErr_SetString(
PyExc_RuntimeError,

View File

@ -436,6 +436,10 @@ class TFETensorTest(test_util.TensorFlowTestCase):
class TFETensorUtilTest(test_util.TensorFlowTestCase):
def setUp(self):
super(TFETensorUtilTest, self).setUp()
context.ensure_initialized()
def testListOfThree(self):
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)