Update TFE_Py_TensorShapeSlice to use abstract interface APIs instead of TF_* APIs.
This is so that the method can be work with TFRT. PiperOrigin-RevId: 342877130 Change-Id: I1d734daeec17b67c5ae602163d6c1ceb67088fa6
This commit is contained in:
parent
02df2f9599
commit
7afd763939
@ -153,8 +153,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
Status* status) override;
|
||||
ImmediateExecutionOperation* CreateOperation() override;
|
||||
|
||||
// Convert a TFRT TensorHandle to tensorflow::TensorHandle. In this case,
|
||||
// just forward the input TensorHandle.
|
||||
// This is a virtual helper function to convert TFRT TensorHandle to
|
||||
// tensorflow::TensorHandle. In current runtime EagerContext, just forward
|
||||
// the input since the input tensor handle is already a
|
||||
// tensorflow::TensorHandle.
|
||||
ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
|
||||
ImmediateExecutionTensorHandle* handle) override;
|
||||
|
||||
|
@ -25,6 +25,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
@ -1030,49 +1032,71 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyObject* py_context = GetPyEagerContext();
|
||||
if (py_context == nullptr) {
|
||||
PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
|
||||
"Cannot create EagerTensor when "
|
||||
"EagerContext is not valid")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TFE_Context* ctx = GetContextHandle(py_context);
|
||||
|
||||
Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
|
||||
PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
|
||||
int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
|
||||
auto tensor = tensorflow::make_safe(TF_AllocateTensor(
|
||||
TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
|
||||
int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
|
||||
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
|
||||
PyObject* tensor_obj = tensors_array[i];
|
||||
if (!EagerTensor_CheckExact(tensor_obj)) {
|
||||
PyErr_SetString(PyExc_TypeError,
|
||||
tensorflow::strings::StrCat(
|
||||
"Expected a list of EagerTensors but "
|
||||
"element ",
|
||||
i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
|
||||
TFE_TensorHandle* handle = t->handle;
|
||||
int num_dims = TFE_TensorHandleNumDims(handle, status.get());
|
||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
|
||||
return nullptr;
|
||||
auto status = tensorflow::make_safe(TF_NewStatus());
|
||||
|
||||
// Create an empty tensor.
|
||||
auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
|
||||
tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
|
||||
|
||||
if (num_tensors_int > 0) {
|
||||
int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
|
||||
|
||||
// Fill the tensor with dims.
|
||||
for (Py_ssize_t i = 0; i < num_tensors; ++i) {
|
||||
PyObject* tensor_obj = tensors_array[i];
|
||||
if (!EagerTensor_CheckExact(tensor_obj)) {
|
||||
PyErr_SetString(
|
||||
PyExc_TypeError,
|
||||
tensorflow::strings::StrCat("Expected a list of EagerTensors but "
|
||||
"element ",
|
||||
i, " has type \"",
|
||||
Py_TYPE(tensor_obj)->tp_name, "\"")
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
|
||||
TFE_TensorHandle* handle = t->handle;
|
||||
int num_dims = TFE_TensorHandleNumDims(handle, status.get());
|
||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (slice_dim >= num_dims) {
|
||||
PyErr_SetString(
|
||||
PyExc_IndexError,
|
||||
tensorflow::strings::StrCat("Slice dimension (", slice_dim,
|
||||
") must be smaller than rank of all "
|
||||
"tensors, but tensor at index ",
|
||||
i, " has rank ", num_dims)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
|
||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
|
||||
return nullptr;
|
||||
}
|
||||
data[i] = dim;
|
||||
}
|
||||
if (slice_dim >= num_dims) {
|
||||
PyErr_SetString(
|
||||
PyExc_IndexError,
|
||||
tensorflow::strings::StrCat("Slice dimension (", slice_dim,
|
||||
") must be smaller than rank of all "
|
||||
"tensors, but tensor at index ",
|
||||
i, " has rank ", num_dims)
|
||||
.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
|
||||
if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
|
||||
return nullptr;
|
||||
}
|
||||
data[i] = dim;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
|
||||
TFE_TensorHandle* handle =
|
||||
tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
|
||||
|
||||
if (!status->status.ok()) {
|
||||
PyErr_SetString(
|
||||
PyExc_RuntimeError,
|
||||
|
@ -436,6 +436,10 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TFETensorUtilTest, self).setUp()
|
||||
context.ensure_initialized()
|
||||
|
||||
def testListOfThree(self):
|
||||
t1 = _create_tensor([[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32)
|
||||
t2 = _create_tensor([[1, 2, 5], [3, 4, 5]], dtype=dtypes.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user