Change PySeqToTensor to return TFE_TensorHandle
PiperOrigin-RevId: 289108443
Change-Id: I2aac99acb068b0dae2f8aabf72e323d0d303ebb1

This commit is contained in:
  parent d88b067ef1
  commit 8a33966dbf

tensorflow/python
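In short, the conversion entry point changes as follows (declarations taken from the py_seq_tensor.h hunk at the bottom of this diff):

// Removed: conversion into a C++ Tensor, with errors reported via Status.
Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret);

// Added: conversion straight to an eager TFE_TensorHandle; on failure the
// function sets the Python error indicator and returns nullptr.
TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
                                          DataType dtype);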
@@ -825,6 +825,7 @@ cc_library(
         ":numpy_lib",
         ":py_util",
         ":safe_ptr",
+        "//tensorflow/c/eager:c_api_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//third_party/python_runtime:headers",
@@ -252,25 +252,6 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
 #undef RETURN_ERROR
 }

-TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value,
-                                          DataType dtype) {
-  tensorflow::TensorHandle* handle = nullptr;
-  tensorflow::Tensor t;
-  // TODO(josh11b): Have PySeqToTensor set python errors instead of
-  // returning Status.
-  auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
-  if (cppstatus.ok()) {
-    cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
-        t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
-  }
-  if (!cppstatus.ok()) {
-    PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
-    return nullptr;
-  }
-  CHECK_NE(handle, nullptr);
-  return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)};
-}
-
 TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
                                                PyObject* value,
                                                tensorflow::DataType dtype,
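Note: the wrapper deleted above set a Python ValueError and returned nullptr on failure; per the new header comment later in this diff, tensorflow::PySeqToTFE_TensorHandle now does that itself, so a call site can use it directly. A minimal sketch under that assumption (the surrounding function HandleFromSequence is hypothetical, not part of the commit):

#include <Python.h>

#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"

// Hypothetical call site: on failure the callee has already set the Python
// error indicator, so nullptr can simply be propagated.
TFE_TensorHandle* HandleFromSequence(TFE_Context* ctx, PyObject* value,
                                     tensorflow::DataType dtype) {
  return tensorflow::PySeqToTFE_TensorHandle(ctx, value, dtype);
}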
@@ -15,6 +15,7 @@ limitations under the License.

 #include "tensorflow/python/lib/core/py_seq_tensor.h"

+#include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
@@ -67,7 +68,7 @@ bool IsPyFloat(PyObject* obj) {

 struct ConverterState {
   // The inferred tensor shape.
-  TensorShape inferred_shape;
+  gtl::InlinedVector<int64, 4> inferred_shape;

   // The inferred tensor data type.
   DataType inferred_dtype;
@@ -155,14 +156,14 @@ Status InferShapeAndType(PyObject* obj, ConverterState* state) {
     } else if (PySequence_Check(obj)) {
       auto length = PySequence_Length(obj);
       if (length > 0) {
-        state->inferred_shape.AddDim(length);
+        state->inferred_shape.push_back(length);
         PyObject* elem = nullptr;
         TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem));
         obj = elem;
         refs_to_clean.push_back(make_safe(obj));
         continue;
       } else if (length == 0) {
-        state->inferred_shape.AddDim(length);
+        state->inferred_shape.push_back(length);
         state->inferred_dtype = DT_INVALID;  // Invalid dtype for empty tensors.
       } else {
         // The sequence does not have a valid length (PySequence_Length < 0).
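Illustrative only (not from the commit): the inferred shape is now accumulated as a flat list of dimension sizes and only materialized as a TensorShape when the result Tensor is built, mirroring the substitutions above (AddDim -> push_back, dim_size(i) -> [i], dims() -> size()). The helper function below is hypothetical:

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"

tensorflow::TensorShape MakeShapeFromInferredDims() {
  tensorflow::gtl::InlinedVector<tensorflow::int64, 4> dims;
  dims.push_back(2);  // was: shape.AddDim(2)
  dims.push_back(3);  // was: shape.AddDim(3)
  // dims[0] replaces shape.dim_size(0); dims.size() replaces shape.dims().
  return tensorflow::TensorShape(dims);  // TensorShape built only at the end.
}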
@@ -247,12 +248,12 @@ struct Converter {
     Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
     if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;

-    const int64 s = state->inferred_shape.dim_size(depth);
+    const int64 s = state->inferred_shape[depth];
     if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
       return ErrorRectangular;
     }

-    if (state->inferred_shape.dims() - depth > 1) {
+    if (state->inferred_shape.size() - depth > 1) {
       /* Iterate over outer dim, and recursively convert each element. */
       for (int64 i = 0; i < s; ++i) {
         const char* error = Helper(PySequence_Fast_GET_ITEM(seq.get(), i),
@@ -272,24 +273,31 @@ struct Converter {
     return nullptr;
   }

-  static const char* Convert(PyObject* obj, ConverterState* state,
-                             Tensor* dest) {
+  static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
+                        TFE_TensorHandle** h, const char** error) {
     /* TODO(josh11b): Allocator & attributes? */
-    Tensor result(ConverterTraits<T>::kTypeEnum, state->inferred_shape);
-    if (state->inferred_shape.dims() == 0) { /* Scalar case */
+    Tensor result(ConverterTraits<T>::kTypeEnum,
+                  TensorShape(state->inferred_shape));
+    if (state->inferred_shape.empty()) { /* Scalar case */
       T value;
       auto scalar = ZeroDimArrayToScalar(obj, state);
-      const char* error = ConverterTraits<T>::ConvertScalar(scalar, &value);
+      *error = ConverterTraits<T>::ConvertScalar(scalar, &value);
       Py_DECREF(scalar);
-      if (error != nullptr) return error;
+      if (*error != nullptr) return errors::InvalidArgument(*error);
       result.scalar<T>()() = value;
     } else {
       T* buf = result.flat<T>().data();
-      const char* error = Helper(obj, 0, state, &buf);
-      if (error != nullptr) return error;
+      *error = Helper(obj, 0, state, &buf);
+      if (*error != nullptr) return errors::InvalidArgument(*error);
     }
-    *dest = result;
-    return nullptr;
+    tensorflow::TensorHandle* handle = nullptr;
+    auto status = tensorflow::TensorHandle::CreateLocalHandle(
+        result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
+    if (!status.ok()) {
+      return status;
+    }
+    *h = new TFE_TensorHandle{TensorHandleInterface(handle)};
+    return Status::OK();
   }
 };

@@ -592,16 +600,14 @@ typedef Converter<bool> BoolConverter;

 }  // namespace

-#define RETURN_STRING_AS_STATUS(...)                             \
-  do {                                                           \
-    const char* _error = (__VA_ARGS__);                          \
-    if (TF_PREDICT_TRUE(_error == nullptr)) return Status::OK(); \
-    return errors::InvalidArgument(_error);                      \
-  } while (0)
-
-Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
+TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
+                                          DataType dtype) {
   ConverterState state;
-  TF_RETURN_IF_ERROR(InferShapeAndType(obj, &state));
+  Status status = InferShapeAndType(obj, &state);
+  if (!status.ok()) {
+    PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
+    return nullptr;
+  }
   DataType requested_dtype = DT_INVALID;
   if (dtype != DT_INVALID) {
     requested_dtype = dtype;
@@ -610,116 +616,131 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
   // we just try instead to create a tensor of the inferred type and
   // let the caller convert it to the requested type using a cast
   // operation.
+  const char* error = nullptr;
+  TFE_TensorHandle* handle = nullptr;
+  status = errors::Unimplemented("Missing Python -> Tensor conversion for ",
+                                 DataTypeString(state.inferred_dtype));
   switch (requested_dtype) {
     case DT_FLOAT:
-      if (FloatConverter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_DOUBLE:
-      if (DoubleConverter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_HALF:
-      if (NumpyHalfConverter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_INT64:
-      if (Int64Converter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_INT32:
-      if (Int32Converter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = Int32Converter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_UINT64:
-      if (UInt64Converter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = UInt64Converter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_COMPLEX128:
-      if (Complex128Converter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_STRING:
-      if (StringConverter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = StringConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_BOOL:
-      if (BoolConverter::Convert(obj, &state, ret) == nullptr)
-        return Status::OK();
+      status = BoolConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     default:
       break;
   }
+  if (status.ok()) return handle;

   switch (state.inferred_dtype) {
     case DT_FLOAT:
       // TODO(josh11b): Handle mixed floats and complex numbers?
       if (requested_dtype == DT_INVALID) {
         // TensorFlow uses float32s to represent floating point numbers
         // by default (for space and speed over using doubles).
-        RETURN_STRING_AS_STATUS(FloatConverter::Convert(obj, &state, ret));
+        status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
       } else {
         // We are going to do a cast to the user's requested dtype
         // after this. We use doubles for this intermediate result so
         // we don't lose precision that might be representable in the
         // final type.
-        RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret));
+        status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
       }
       break;

     case DT_DOUBLE:
-      RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret));
+      status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_HALF:
-      RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, &state, ret));
+      status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_INT64:
       if (requested_dtype == DT_INVALID) {
-        const char* error = Int32Converter::Convert(obj, &state, ret);
+        status = Int32Converter::Convert(ctx, obj, &state, &handle, &error);
         if (error == ErrorFoundInt64) {
-          error = Int64Converter::Convert(obj, &state, ret);
+          status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
         }
         if (error == ErrorFoundFloat) {
-          error = FloatConverter::Convert(obj, &state, ret);
+          status = FloatConverter::Convert(ctx, obj, &state, &handle, &error);
         }
         // TODO(josh11b): May also want to fall back to using doubles if
         // error == ErrorOutOfRange?
-        RETURN_STRING_AS_STATUS(error);
       } else {
-        const char* error = Int64Converter::Convert(obj, &state, ret);
+        status = Int64Converter::Convert(ctx, obj, &state, &handle, &error);
         if (error == ErrorFoundFloat) {
-          error = DoubleConverter::Convert(obj, &state, ret);
+          status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error);
         }
-        RETURN_STRING_AS_STATUS(error);
       }
       break;

     case DT_STRING:
-      RETURN_STRING_AS_STATUS(StringConverter::Convert(obj, &state, ret));
+      status = StringConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_COMPLEX128:
-      RETURN_STRING_AS_STATUS(Complex128Converter::Convert(obj, &state, ret));
+      status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_BOOL:
-      RETURN_STRING_AS_STATUS(BoolConverter::Convert(obj, &state, ret));
+      status = BoolConverter::Convert(ctx, obj, &state, &handle, &error);
       break;

     case DT_INVALID:  // Only occurs for empty tensors.
-      *ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
-                    state.inferred_shape);
-      return Status::OK();
+    {
+      tensorflow::TensorHandle* h = nullptr;
+      Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
+                    TensorShape(state.inferred_shape));
+      status = tensorflow::TensorHandle::CreateLocalHandle(
+          tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h);
+      if (!status.ok()) {
+        PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
+        return nullptr;
+      }
+      return new TFE_TensorHandle{TensorHandleInterface(h)};
+    }

     default:
-      return errors::Unimplemented("Missing Python -> Tensor conversion for ",
-                                   DataTypeString(state.inferred_dtype));
+      break;
   }

-  return Status::OK();
+  if (!status.ok()) {
+    PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
+    return nullptr;
+  }
+
+  return handle;
 }

 }  // namespace tensorflow
@@ -18,6 +18,7 @@ limitations under the License.

 #include <Python.h>

+#include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"

@@ -25,12 +26,16 @@ namespace tensorflow {

 // Converts Python object `obj` representing a rectangular array of
 // Python values (a scalar, a sequence of scalars, a sequence of
-// sequences, etc.) into a C++ TensorFlow Tensor and stores it in
-// *ret. If dtype is not None it should by a Python integer
+// sequences, etc.) into a TFE_TensorHandle.
+// If dtype is not None it should by a Python integer
 // representing the desired dtype of the resulting Tensor.
 // This is used only as a hint, *ret may not have that dtype on
 // success and may require a cast.
-Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret);
+//
+// If an error occurs, this return nullptr and sets the python error indicator
+// with PyErr_SetString.
+TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
+                                          DataType dtype);

 }  // namespace tensorflow

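The header comment above keeps dtype as a hint only, so the returned handle may still need a cast. A hedged sketch of how a caller might check for that, assuming the public TFE_TensorHandleDataType C API and the fact that TF_DataType and tensorflow::DataType share enum values (the helper ConvertNeedsCast is hypothetical, not part of the commit):

#include <Python.h>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"

// Hypothetical helper: convert a Python sequence and report whether the
// caller still has to cast the result to the requested dtype.
bool ConvertNeedsCast(TFE_Context* ctx, PyObject* obj,
                      tensorflow::DataType dtype, TFE_TensorHandle** out) {
  *out = tensorflow::PySeqToTFE_TensorHandle(ctx, obj, dtype);
  if (*out == nullptr) return false;  // Python error indicator already set.
  return dtype != tensorflow::DT_INVALID &&
         static_cast<TF_DataType>(dtype) != TFE_TensorHandleDataType(*out);
}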