Use Status directly in TF_TensorFromTensor
PiperOrigin-RevId: 288700381
Change-Id: I159ea1c87ee3ca4f10db80b540bb7aedf5a7a967
parent 3f9dd57093
commit f7239df1a3
Changed paths:
  tensorflow/c
  tensorflow/compiler/mlir/tensorflow/utils
  tensorflow/python/lib/core
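
The gist of the change: TF_TensorFromTensor now takes a tensorflow::Status* instead of a TF_Status*, so C++ callers no longer allocate and free a C-level status object, while C API call sites pass the Status wrapped inside their TF_Status (hence the &status->status arguments below). A minimal sketch of the new calling convention, assuming only what the diff shows (the ConvertOrNull helper is hypothetical, added here for illustration):

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {
    // Declared in the C API internals; repeated here (as the tests below do)
    // only to show the new shape of the signature.
    TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
    }  // namespace tensorflow

    // Hypothetical helper: convert a tensorflow::Tensor using a stack Status.
    TF_Tensor* ConvertOrNull(const tensorflow::Tensor& src) {
      tensorflow::Status status;
      TF_Tensor* t = tensorflow::TF_TensorFromTensor(src, &status);
      if (!status.ok()) {
        // status.error_message() replaces TF_Message(); no TF_DeleteStatus needed.
        return nullptr;
      }
      return t;
    }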
@@ -458,7 +458,7 @@ static void TF_Run_Helper(
           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
       continue;
     }
-    c_outputs[i] = TF_TensorFromTensor(src, status);
+    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
     if (!status->status.ok()) return;
   }
 }
@@ -1493,7 +1493,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
   Tensor t;
   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
   if (!status->status.ok()) return;
-  *value = TF_TensorFromTensor(t, status);
+  *value = TF_TensorFromTensor(t, &status->status);
 }
 
 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
@@ -1504,7 +1504,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
   if (!status->status.ok()) return;
   const auto len = std::min(max_values, static_cast<int>(ts.size()));
   for (int i = 0; i < len; ++i) {
-    values[i] = TF_TensorFromTensor(ts[i], status);
+    values[i] = TF_TensorFromTensor(ts[i], &status->status);
   }
 }
 
@@ -2398,7 +2398,7 @@ unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
       graph->graph.versions().producer(), &evaluated, &result_tensor);
   if (evaluated) {
     DCHECK(status->status.ok());
-    *result = TF_TensorFromTensor(result_tensor, status);
+    *result = TF_TensorFromTensor(result_tensor, &status->status);
     if (!status->status.ok()) evaluated = false;
   }
   return evaluated;
@@ -634,7 +634,7 @@ TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
   std::unique_ptr<tensorflow::Tensor> tensor;
   reader->GetTensor(name, &tensor, status);
   if (!status->status.ok()) return nullptr;
-  return tensorflow::TF_TensorFromTensor(*tensor, status);
+  return tensorflow::TF_TensorFromTensor(*tensor, &status->status);
 }
 
 void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
@@ -188,7 +188,7 @@ namespace tensorflow {
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 
 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                        TF_Buffer* out);
@@ -51,7 +51,7 @@ limitations under the License.
 #include "tensorflow/core/util/equal_graph_def.h"
 
 namespace tensorflow {
-TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
 
 namespace {
@@ -227,7 +227,7 @@ TEST(CAPI, LibraryLoadFunctions) {
 
 void TestEncodeDecode(int line, const std::vector<string>& data) {
   const tensorflow::int64 n = data.size();
-  TF_Status* status = TF_NewStatus();
+  Status status;
   for (const std::vector<tensorflow::int64>& dims :
        std::vector<std::vector<tensorflow::int64>>{
            {n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@@ -236,8 +236,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
     for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
       src.flat<tstring>()(i) = data[i];
     }
-    TF_Tensor* dst = TF_TensorFromTensor(src, status);
-    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+    TF_Tensor* dst = TF_TensorFromTensor(src, &status);
+    ASSERT_TRUE(status.ok()) << status.error_message();
 
     // Convert back to a C++ Tensor and ensure we get expected output.
     Tensor output;
@@ -249,7 +249,6 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
 
     TF_DeleteTensor(dst);
   }
-  TF_DeleteStatus(status);
 }
 
 TEST(CAPI, TensorEncodeDecodeStrings) {
@@ -1394,8 +1393,9 @@ TEST(CAPI, SavedModel) {
   TF_Operation* input_op =
       TF_GraphOperationByName(graph, input_op_name.c_str());
   ASSERT_TRUE(input_op != nullptr);
-  csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
-  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+  Status status;
+  csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}});
+  ASSERT_TRUE(status.ok()) << status.error_message();
 
   const tensorflow::string output_op_name(
       tensorflow::ParseTensorName(output_name).first);
@@ -2522,12 +2522,11 @@ TEST(CAPI, TestTensorIsNotAligned) {
 
   // Take an unaligned slice.
   Tensor y = x.Slice(1, 13);
-  TF_Status* status = TF_NewStatus();
-  TF_Tensor* a = TF_TensorFromTensor(y, status);
+  Status status;
+  TF_Tensor* a = TF_TensorFromTensor(y, &status);
   if (EIGEN_MAX_ALIGN_BYTES > 0) {
     EXPECT_FALSE(TF_TensorIsAligned(a));
   }
-  TF_DeleteStatus(status);
   TF_DeleteTensor(a);
 }
 
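The test updates above all follow the same pattern: the heap-allocated TF_Status and its TF_GetCode/TF_Message/TF_DeleteStatus bookkeeping are replaced by a tensorflow::Status on the stack, checked with ok(). A small gtest-style sketch of that pattern, assuming only what the diff shows (test name and tensor contents are made up for illustration):

    #include "gtest/gtest.h"
    #include "tensorflow/c/c_api.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {
    TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
    }  // namespace tensorflow

    TEST(TensorFromTensorSketch, StackStatus) {
      tensorflow::Tensor src(tensorflow::DT_FLOAT,
                             tensorflow::TensorShape({2, 2}));
      src.flat<float>().setZero();

      // Old: TF_Status* status = TF_NewStatus(); ... TF_DeleteStatus(status);
      // New: a plain Status, no allocation or cleanup.
      tensorflow::Status status;
      TF_Tensor* dst = tensorflow::TF_TensorFromTensor(src, &status);
      ASSERT_TRUE(status.ok()) << status.error_message();
      TF_DeleteTensor(dst);
    }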
@@ -992,7 +992,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
       h_cpu->Unref();
       return nullptr;
     }
-    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
+    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, &status->status);
     h_cpu->Unref();
     return retval;
   } else {
@@ -1008,7 +1008,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
       status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
       if (!status->status.ok()) return nullptr;
     }
-    return tensorflow::TF_TensorFromTensor(tensor, status);
+    return tensorflow::TF_TensorFromTensor(tensor, &status->status);
   }
 }
 
@@ -181,7 +181,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
     return;
   }
   const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
-  TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
+  TF_Tensor* result =
+      ::tensorflow::TF_TensorFromTensor(cc_tensor, &status->status);
   if (TF_GetCode(status) == TF_OK) {
     *tensor = result;
   }
@@ -170,6 +170,11 @@ void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
 }
 
 // --------------------------------------------------------------------------
+void StringEncode(const char* src, size_t src_len, char* dst) {
+  dst = tensorflow::core::EncodeVarint64(dst, src_len);
+  memcpy(dst, src, src_len);
+}
+
 size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                        size_t dst_len, TF_Status* status) {
   const size_t sz = TF_StringEncodedSize(src_len);
@@ -185,8 +190,7 @@ size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                            src_len, "-byte string"));
     return 0;
  }
-  dst = tensorflow::core::EncodeVarint64(dst, src_len);
-  memcpy(dst, src, src_len);
+  StringEncode(src, src_len, dst);
   return sz;
 }
 
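For context on the StringEncode refactor above: each DT_STRING element is encoded as a varint64 length prefix followed by the raw bytes, so the encoded size is knowable up front (TF_StringEncodedSize) and the per-element status plumbing can be dropped. A rough standalone sketch of that encoding, assuming the protobuf-style varint used by tensorflow::core::EncodeVarint64 (the function names below are hypothetical, not the TF_* API):

    #include <cstdint>
    #include <cstring>
    #include <string>

    // Bytes needed for the varint64 length prefix.
    static size_t VarintLength(uint64_t v) {
      size_t len = 1;
      while (v >= 0x80) { v >>= 7; ++len; }
      return len;
    }

    // Mirrors TF_StringEncodedSize(src_len): prefix bytes + payload bytes.
    static size_t EncodedSize(size_t src_len) {
      return VarintLength(src_len) + src_len;
    }

    // Mirrors the new StringEncode helper: write the length prefix, then the data.
    // dst must have room for EncodedSize(src.size()) bytes.
    static void EncodeString(const std::string& src, char* dst) {
      uint64_t v = src.size();
      while (v >= 0x80) { *dst++ = static_cast<char>((v & 0x7f) | 0x80); v >>= 7; }
      *dst++ = static_cast<char>(v);
      memcpy(dst, src.data(), src.size());
    }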
@@ -245,13 +249,11 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype,
 namespace tensorflow {
 
 // Non-static for testing.
-TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
-                               TF_Status* status) {
-  TF_SetStatus(status, TF_OK, "");
+TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
+  *status = tensorflow::Status::OK();
   if (!src.IsInitialized()) {
-    Set_TF_Status_from_Status(
-        status, FailedPrecondition(
-                    "attempt to use a tensor with an uninitialized value"));
+    *status = FailedPrecondition(
+        "attempt to use a tensor with an uninitialized value");
     return nullptr;
   }
   if (src.NumElements() == 0) {
@@ -259,14 +261,13 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
   }
   if (src.dtype() == tensorflow::DT_RESOURCE) {
     if (src.shape().dims() != 0) {
-      Set_TF_Status_from_Status(
-          status, InvalidArgument(
-                      "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
-                      src.shape().DebugString(),
-                      "). Please file a bug at "
-                      "https://github.com/tensorflow/tensorflow/issues/new, "
-                      "ideally with a "
-                      "short code snippet that reproduces this error."));
+      *status = InvalidArgument(
+          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
+          src.shape().DebugString(),
+          "). Please file a bug at "
+          "https://github.com/tensorflow/tensorflow/issues/new, "
+          "ideally with a "
+          "short code snippet that reproduces this error.");
       return nullptr;
     }
     const string str =
@@ -305,23 +306,15 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
     *offsets = (dst - data_start);
     offsets++;
     const string& s = srcarray(i);
-    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
-    if (TF_GetCode(status) != TF_OK) {
-      Set_TF_Status_from_Status(
-          status,
-          InvalidArgument("invalid string tensor encoding (string #", i, " of ",
-                          srcarray.size(), "): ", TF_Message(status)));
-      delete[] base;
-      return nullptr;
-    }
+    const size_t consumed = TF_StringEncodedSize(s.size());
+    StringEncode(s.data(), s.size(), dst);
     dst += consumed;
     dst_len -= consumed;
   }
   if (dst != base + size) {
-    Set_TF_Status_from_Status(
-        status, InvalidArgument(
-                    "invalid string tensor encoding (decoded ", (dst - base),
-                    " bytes, but the tensor is encoded in ", size, " bytes"));
+    *status = InvalidArgument(
+        "invalid string tensor encoding (decoded ", (dst - base),
+        " bytes, but the tensor is encoded in ", size, " bytes");
     delete[] base;
     return nullptr;
   }
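With the Status-based signature, error paths inside TF_TensorFromTensor assign the result of the error helpers (FailedPrecondition, InvalidArgument) directly to *status instead of routing through Set_TF_Status_from_Status. Code that still surfaces errors through the C API can convert once at the boundary; a hedged sketch of that shape (the wrapper function is hypothetical; Set_TF_Status_from_Status is the helper from tensorflow/c/tf_status_helper.h):

    #include "tensorflow/c/c_api.h"
    #include "tensorflow/c/tf_status_helper.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {
    TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status);
    }  // namespace tensorflow

    // Hypothetical boundary wrapper: run the Status-based conversion, then copy
    // the result into the caller's TF_Status so TF_GetCode/TF_Message still work.
    TF_Tensor* TensorFromTensorForCApi(const tensorflow::Tensor& src,
                                       TF_Status* c_status) {
      tensorflow::Status status;
      TF_Tensor* result = tensorflow::TF_TensorFromTensor(src, &status);
      tensorflow::Set_TF_Status_from_Status(c_status, status);
      return result;
    }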
@@ -122,7 +122,7 @@ mlir::LogicalResult EvaluateOperation(
   for (const auto operand : operands) {
     Tensor tensor;
     RETURN_FAILURE_IF_ERROR(ConvertToTensor(operand, &tensor));
-    TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, status);
+    TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status->status);
     RETURN_FAILURE_IF_ERROR(status);
     auto clean_tensor =
         MakeCleanup([tf_tensor] { TF_DeleteTensor(tf_tensor); });
@@ -539,8 +539,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
 }
 
 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
-TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
-                               TF_Status* status);
+TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status);
 
 Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
   Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
@@ -552,12 +551,10 @@ Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
 }
 
 Status TensorToNdarray(const Tensor& t, PyObject** ret) {
-  TF_Status* status = TF_NewStatus();
-  Safe_TF_TensorPtr tf_tensor = make_safe(TF_TensorFromTensor(t, status));
-  Status tf_status = StatusFromTF_Status(status);
-  TF_DeleteStatus(status);
-  if (!tf_status.ok()) {
-    return tf_status;
+  Status status;
+  Safe_TF_TensorPtr tf_tensor = make_safe(TF_TensorFromTensor(t, &status));
+  if (!status.ok()) {
+    return status;
   }
   return TF_TensorToPyArray(std::move(tf_tensor), ret);
 }