Change PySeqToTensor to return TFE_TensorHandle

PiperOrigin-RevId: 289108443
Change-Id: I2aac99acb068b0dae2f8aabf72e323d0d303ebb1
This commit is contained in:
Gaurav Jain 2020-01-10 09:40:28 -08:00 committed by TensorFlower Gardener
parent d88b067ef1
commit 8a33966dbf
4 changed files with 92 additions and 84 deletions

View File

@ -825,6 +825,7 @@ cc_library(
":numpy_lib",
":py_util",
":safe_ptr",
"//tensorflow/c/eager:c_api_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/python_runtime:headers",

View File

@ -252,25 +252,6 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
#undef RETURN_ERROR
}
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
DataType dtype) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Tensor t;
// TODO(josh11b): Have PySeqToTensor set python errors instead of
// returning Status.
auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
return nullptr;
}
CHECK_NE(handle, nullptr);
return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
}
TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
PyObject* value,
tensorflow::DataType dtype,

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
@ -67,7 +68,7 @@ bool IsPyFloat(PyObject* obj) {
struct ConverterState {
// The inferred tensor shape.
TensorShape inferred_shape;
gtl::InlinedVector<int64, 4> inferred_shape;
// The inferred tensor data type.
DataType inferred_dtype;
@ -155,14 +156,14 @@ Status InferShapeAndType(PyObject* obj, ConverterState* state) {
} else if (PySequence_Check(obj)) {
auto length = PySequence_Length(obj);
if (length > 0) {
state->inferred_shape.AddDim(length);
state->inferred_shape.push_back(length);
PyObject* elem = nullptr;
TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem));
obj = elem;
refs_to_clean.push_back(make_safe(obj));
continue;
} else if (length == 0) {
state->inferred_shape.AddDim(length);
state->inferred_shape.push_back(length);
state->inferred_dtype = DT_INVALID; // Invalid dtype for empty tensors.
} else {
// The sequence does not have a valid length (PySequence_Length < 0).
@ -247,12 +248,12 @@ struct Converter {
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;
const int64 s = state->inferred_shape.dim_size(depth);
const int64 s = state->inferred_shape[depth];
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
return ErrorRectangular;
}
if (state->inferred_shape.dims() - depth > 1) {
if (state->inferred_shape.size() - depth > 1) {
/* Iterate over outer dim, and recursively convert each element. */
for (int64 i = 0; i < s; ++i) {
const char* error = Helper(PySequence_Fast_GET_ITEM(seq.get(), i),
@ -272,24 +273,31 @@ struct Converter {
return nullptr;
}
static const char* Convert(PyObject* obj, ConverterState* state,
Tensor* dest) {
static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
TFE_TensorHandle** h, const char** error) {
/* TODO(josh11b): Allocator & attributes? */
Tensor result(ConverterTraits<T>::kTypeEnum, state->inferred_shape);
if (state->inferred_shape.dims() == 0) { /* Scalar case */
Tensor result(ConverterTraits<T>::kTypeEnum,
TensorShape(state->inferred_shape));
if (state->inferred_shape.empty()) { /* Scalar case */
T value;
auto scalar = ZeroDimArrayToScalar(obj, state);
const char* error = ConverterTraits<T>::ConvertScalar(scalar, &value);
*error = ConverterTraits<T>::ConvertScalar(scalar, &value);
Py_DECREF(scalar);
if (error != nullptr) return error;
if (*error != nullptr) return errors::InvalidArgument(*error);
result.scalar<T>()() = value;
} else {
T* buf = result.flat<T>().data();
const char* error = Helper(obj, 0, state, &buf);
if (error != nullptr) return error;
*error = Helper(obj, 0, state, &buf);
if (*error != nullptr) return errors::InvalidArgument(*error);
}
*dest = result;
return nullptr;
tensorflow::TensorHandle* handle = nullptr;
auto status = tensorflow::TensorHandle::CreateLocalHandle(
result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
if (!status.ok()) {
return status;
}
*h = new TFE_TensorHandle{TensorHandleInterface(handle)};
return Status::OK();
}
};
@ -592,16 +600,14 @@ typedef Converter<bool> BoolConverter;
} // namespace
#define RETURN_STRING_AS_STATUS(...) \
do { \
const char* _error = (__VA_ARGS__); \
if (TF_PREDICT_TRUE(_error == nullptr)) return Status::OK(); \
return errors::InvalidArgument(_error); \
} while (0)
Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
DataType dtype) {
ConverterState state;
TF_RETURN_IF_ERROR(InferShapeAndType(obj, &state));
Status status = InferShapeAndType(obj, &state);
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
return nullptr;
}
DataType requested_dtype = DT_INVALID;
if (dtype != DT_INVALID) {
requested_dtype = dtype;
@ -610,116 +616,131 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
// we just try instead to create a tensor of the inferred type and
// let the caller convert it to the requested type using a cast
// operation.
const char* error = nullptr;
TFE_TensorHandle* handle = nullptr;
status = errors::Unimplemented("Missing Python -> Tensor conversion for ",
DataTypeString(state.inferred_dtype));
switch (requested_dtype) {
case DT_FLOAT:
if (FloatConverter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_DOUBLE:
if (DoubleConverter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_HALF:
if (NumpyHalfConverter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_INT64:
if (Int64Converter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_INT32:
if (Int32Converter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = Int32Converter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_UINT64:
if (UInt64Converter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = UInt64Converter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_COMPLEX128:
if (Complex128Converter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_STRING:
if (StringConverter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = StringConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_BOOL:
if (BoolConverter::Convert(obj, &state, ret) == nullptr)
return Status::OK();
status = BoolConverter::Convert(ctx, obj, &state, &handle, &error);
break;
default:
break;
}
if (status.ok()) return handle;
switch (state.inferred_dtype) {
case DT_FLOAT:
// TODO(josh11b): Handle mixed floats and complex numbers?
if (requested_dtype == DT_INVALID) {
// TensorFlow uses float32s to represent floating point numbers
// by default (for space and speed over using doubles).
RETURN_STRING_AS_STATUS(FloatConverter::Convert(obj, &state, ret));
status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
} else {
// We are going to do a cast to the user's requested dtype
// after this. We use doubles for this intermediate result so
// we don't lose precision that might be representable in the
// final type.
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret));
status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
}
break;
case DT_DOUBLE:
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret));
status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_HALF:
RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, &state, ret));
status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_INT64:
if (requested_dtype == DT_INVALID) {
const char* error = Int32Converter::Convert(obj, &state, ret);
status = Int32Converter::Convert(ctx, obj, &state, &handle, &error);
if (error == ErrorFoundInt64) {
error = Int64Converter::Convert(obj, &state, ret);
status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
}
if (error == ErrorFoundFloat) {
error = FloatConverter::Convert(obj, &state, ret);
status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
}
// TODO(josh11b): May also want to fall back to using doubles if
// error == ErrorOutOfRange?
RETURN_STRING_AS_STATUS(error);
} else {
const char* error = Int64Converter::Convert(obj, &state, ret);
status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
if (error == ErrorFoundFloat) {
error = DoubleConverter::Convert(obj, &state, ret);
status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
}
RETURN_STRING_AS_STATUS(error);
}
break;
case DT_STRING:
RETURN_STRING_AS_STATUS(StringConverter::Convert(obj, &state, ret));
status = StringConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_COMPLEX128:
RETURN_STRING_AS_STATUS(Complex128Converter::Convert(obj, &state, ret));
status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_BOOL:
RETURN_STRING_AS_STATUS(BoolConverter::Convert(obj, &state, ret));
status = BoolConverter::Convert(ctx, obj, &state, &handle, &error);
break;
case DT_INVALID: // Only occurs for empty tensors.
*ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
state.inferred_shape);
return Status::OK();
{
tensorflow::TensorHandle* h = nullptr;
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
status = tensorflow::TensorHandle::CreateLocalHandle(
tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h);
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
return nullptr;
}
return new TFE_TensorHandle{TensorHandleInterface(h)};
}
default:
return errors::Unimplemented("Missing Python -> Tensor conversion for ",
DataTypeString(state.inferred_dtype));
break;
}
return Status::OK();
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
return nullptr;
}
return handle;
}
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <Python.h>
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@ -25,12 +26,16 @@ namespace tensorflow {
// Converts Python object `obj` representing a rectangular array of
// Python values (a scalar, a sequence of scalars, a sequence of
// sequences, etc.) into a C++ TensorFlow Tensor and stores it in
// *ret. If dtype is not None it should by a Python integer
// sequences, etc.) into a TFE_TensorHandle.
// If dtype is not None it should by a Python integer
// representing the desired dtype of the resulting Tensor.
// This is used only as a hint, *ret may not have that dtype on
// success and may require a cast.
Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret);
//
// If an error occurs, this return nullptr and sets the python error indicator
// with PyErr_SetString.
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
DataType dtype);
} // namespace tensorflow