diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 8306e5c1db0..fe2f98afd00 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -825,6 +825,7 @@ cc_library( ":numpy_lib", ":py_util", ":safe_ptr", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/python_runtime:headers", diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index e6c8e9b32e5..bd938b658e8 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -252,25 +252,6 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle, #undef RETURN_ERROR } -TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* value, - DataType dtype) { - tensorflow::TensorHandle* handle = nullptr; - tensorflow::Tensor t; - // TODO(josh11b): Have PySeqToTensor set python errors instead of - // returning Status. - auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t); - if (cppstatus.ok()) { - cppstatus = tensorflow::TensorHandle::CreateLocalHandle( - t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle); - } - if (!cppstatus.ok()) { - PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str()); - return nullptr; - } - CHECK_NE(handle, nullptr); - return new TFE_TensorHandle{tensorflow::TensorHandleInterface(handle)}; -} - TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, PyObject* value, tensorflow::DataType dtype, diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 5d4916f48fc..89aa44ea298 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/python/lib/core/py_seq_tensor.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -67,7 +68,7 @@ bool IsPyFloat(PyObject* obj) { struct ConverterState { // The inferred tensor shape. - TensorShape inferred_shape; + gtl::InlinedVector inferred_shape; // The inferred tensor data type. DataType inferred_dtype; @@ -155,14 +156,14 @@ Status InferShapeAndType(PyObject* obj, ConverterState* state) { } else if (PySequence_Check(obj)) { auto length = PySequence_Length(obj); if (length > 0) { - state->inferred_shape.AddDim(length); + state->inferred_shape.push_back(length); PyObject* elem = nullptr; TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem)); obj = elem; refs_to_clean.push_back(make_safe(obj)); continue; } else if (length == 0) { - state->inferred_shape.AddDim(length); + state->inferred_shape.push_back(length); state->inferred_dtype = DT_INVALID; // Invalid dtype for empty tensors. } else { // The sequence does not have a valid length (PySequence_Length < 0). @@ -247,12 +248,12 @@ struct Converter { Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; - const int64 s = state->inferred_shape.dim_size(depth); + const int64 s = state->inferred_shape[depth]; if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { return ErrorRectangular; } - if (state->inferred_shape.dims() - depth > 1) { + if (state->inferred_shape.size() - depth > 1) { /* Iterate over outer dim, and recursively convert each element. */ for (int64 i = 0; i < s; ++i) { const char* error = Helper(PySequence_Fast_GET_ITEM(seq.get(), i), @@ -272,24 +273,31 @@ struct Converter { return nullptr; } - static const char* Convert(PyObject* obj, ConverterState* state, - Tensor* dest) { + static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, + TFE_TensorHandle** h, const char** error) { /* TODO(josh11b): Allocator & attributes? */ - Tensor result(ConverterTraits::kTypeEnum, state->inferred_shape); - if (state->inferred_shape.dims() == 0) { /* Scalar case */ + Tensor result(ConverterTraits::kTypeEnum, + TensorShape(state->inferred_shape)); + if (state->inferred_shape.empty()) { /* Scalar case */ T value; auto scalar = ZeroDimArrayToScalar(obj, state); - const char* error = ConverterTraits::ConvertScalar(scalar, &value); + *error = ConverterTraits::ConvertScalar(scalar, &value); Py_DECREF(scalar); - if (error != nullptr) return error; + if (*error != nullptr) return errors::InvalidArgument(*error); result.scalar()() = value; } else { T* buf = result.flat().data(); - const char* error = Helper(obj, 0, state, &buf); - if (error != nullptr) return error; + *error = Helper(obj, 0, state, &buf); + if (*error != nullptr) return errors::InvalidArgument(*error); } - *dest = result; - return nullptr; + tensorflow::TensorHandle* handle = nullptr; + auto status = tensorflow::TensorHandle::CreateLocalHandle( + result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle); + if (!status.ok()) { + return status; + } + *h = new TFE_TensorHandle{TensorHandleInterface(handle)}; + return Status::OK(); } }; @@ -592,16 +600,14 @@ typedef Converter BoolConverter; } // namespace -#define RETURN_STRING_AS_STATUS(...) \ - do { \ - const char* _error = (__VA_ARGS__); \ - if (TF_PREDICT_TRUE(_error == nullptr)) return Status::OK(); \ - return errors::InvalidArgument(_error); \ - } while (0) - -Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) { +TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, + DataType dtype) { ConverterState state; - TF_RETURN_IF_ERROR(InferShapeAndType(obj, &state)); + Status status = InferShapeAndType(obj, &state); + if (!status.ok()) { + PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); + return nullptr; + } DataType requested_dtype = DT_INVALID; if (dtype != DT_INVALID) { requested_dtype = dtype; @@ -610,116 +616,131 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) { // we just try instead to create a tensor of the inferred type and // let the caller convert it to the requested type using a cast // operation. + const char* error = nullptr; + TFE_TensorHandle* handle = nullptr; + status = errors::Unimplemented("Missing Python -> Tensor conversion for ", + DataTypeString(state.inferred_dtype)); switch (requested_dtype) { case DT_FLOAT: - if (FloatConverter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = FloatConverter::Convert(ctx, obj, &state, &handle, &error); break; case DT_DOUBLE: - if (DoubleConverter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error); break; case DT_HALF: - if (NumpyHalfConverter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error); break; case DT_INT64: - if (Int64Converter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = Int64Converter::Convert(ctx, obj, &state, &handle, &error); break; case DT_INT32: - if (Int32Converter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = Int32Converter::Convert(ctx, obj, &state, &handle, &error); break; case DT_UINT64: - if (UInt64Converter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = UInt64Converter::Convert(ctx, obj, &state, &handle, &error); break; case DT_COMPLEX128: - if (Complex128Converter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error); break; case DT_STRING: - if (StringConverter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = StringConverter::Convert(ctx, obj, &state, &handle, &error); break; case DT_BOOL: - if (BoolConverter::Convert(obj, &state, ret) == nullptr) - return Status::OK(); + status = BoolConverter::Convert(ctx, obj, &state, &handle, &error); break; default: break; } + if (status.ok()) return handle; + switch (state.inferred_dtype) { case DT_FLOAT: // TODO(josh11b): Handle mixed floats and complex numbers? if (requested_dtype == DT_INVALID) { // TensorFlow uses float32s to represent floating point numbers // by default (for space and speed over using doubles). - RETURN_STRING_AS_STATUS(FloatConverter::Convert(obj, &state, ret)); + status = FloatConverter::Convert(ctx, obj, &state, &handle, &error); } else { // We are going to do a cast to the user's requested dtype // after this. We use doubles for this intermediate result so // we don't lose precision that might be representable in the // final type. - RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret)); + status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error); } + break; case DT_DOUBLE: - RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, &state, ret)); + status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error); + break; case DT_HALF: - RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, &state, ret)); + status = NumpyHalfConverter::Convert(ctx, obj, &state, &handle, &error); + break; case DT_INT64: if (requested_dtype == DT_INVALID) { - const char* error = Int32Converter::Convert(obj, &state, ret); + status = Int32Converter::Convert(ctx, obj, &state, &handle, &error); if (error == ErrorFoundInt64) { - error = Int64Converter::Convert(obj, &state, ret); + status = Int64Converter::Convert(ctx, obj, &state, &handle, &error); } if (error == ErrorFoundFloat) { - error = FloatConverter::Convert(obj, &state, ret); + status = FloatConverter::Convert(ctx, obj, &state, &handle, &error); } // TODO(josh11b): May also want to fall back to using doubles if // error == ErrorOutOfRange? - RETURN_STRING_AS_STATUS(error); } else { - const char* error = Int64Converter::Convert(obj, &state, ret); + status = Int64Converter::Convert(ctx, obj, &state, &handle, &error); if (error == ErrorFoundFloat) { - error = DoubleConverter::Convert(obj, &state, ret); + status = DoubleConverter::Convert(ctx, obj, &state, &handle, &error); } - RETURN_STRING_AS_STATUS(error); } + break; case DT_STRING: - RETURN_STRING_AS_STATUS(StringConverter::Convert(obj, &state, ret)); + status = StringConverter::Convert(ctx, obj, &state, &handle, &error); + break; case DT_COMPLEX128: - RETURN_STRING_AS_STATUS(Complex128Converter::Convert(obj, &state, ret)); + status = Complex128Converter::Convert(ctx, obj, &state, &handle, &error); + break; case DT_BOOL: - RETURN_STRING_AS_STATUS(BoolConverter::Convert(obj, &state, ret)); + status = BoolConverter::Convert(ctx, obj, &state, &handle, &error); + break; case DT_INVALID: // Only occurs for empty tensors. - *ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, - state.inferred_shape); - return Status::OK(); + { + tensorflow::TensorHandle* h = nullptr; + Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, + TensorShape(state.inferred_shape)); + status = tensorflow::TensorHandle::CreateLocalHandle( + tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h); + if (!status.ok()) { + PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); + return nullptr; + } + return new TFE_TensorHandle{TensorHandleInterface(h)}; + } default: - return errors::Unimplemented("Missing Python -> Tensor conversion for ", - DataTypeString(state.inferred_dtype)); + break; } - return Status::OK(); + if (!status.ok()) { + PyErr_SetString(PyExc_ValueError, status.error_message().c_str()); + return nullptr; + } + + return handle; } } // namespace tensorflow diff --git a/tensorflow/python/lib/core/py_seq_tensor.h b/tensorflow/python/lib/core/py_seq_tensor.h index 25b94a90b16..1c9e2b41f9d 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.h +++ b/tensorflow/python/lib/core/py_seq_tensor.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -25,12 +26,16 @@ namespace tensorflow { // Converts Python object `obj` representing a rectangular array of // Python values (a scalar, a sequence of scalars, a sequence of -// sequences, etc.) into a C++ TensorFlow Tensor and stores it in -// *ret. If dtype is not None it should by a Python integer +// sequences, etc.) into a TFE_TensorHandle. +// If dtype is not None it should by a Python integer // representing the desired dtype of the resulting Tensor. // This is used only as a hint, *ret may not have that dtype on // success and may require a cast. -Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret); +// +// If an error occurs, this return nullptr and sets the python error indicator +// with PyErr_SetString. +TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, + DataType dtype); } // namespace tensorflow