diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 570e84ecc24..49d61d7f3c6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -488,7 +488,13 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { } // Non-static for testing. -TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) { +TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, + TF_Status* status) { + if (!src.IsInitialized()) { + status->status = FailedPrecondition( + "attempt to use a tensor with an uninitialized value"); + return nullptr; + } if (src.dtype() == DT_RESOURCE) { DCHECK_EQ(0, src.shape().dims()) << src.shape().DebugString(); if (src.shape().dims() != 0) { @@ -528,18 +534,26 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) { char* dst = data_start; // Where next string is encoded. size_t dst_len = size - static_cast<size_t>(data_start - base); tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base); - TF_Status status; for (int i = 0; i < srcarray.size(); ++i) { *offsets = (dst - data_start); offsets++; const tensorflow::string& s = srcarray(i); - size_t consumed = - TF_StringEncode(s.data(), s.size(), dst, dst_len, &status); - CHECK(status.status.ok()); + size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); + if (!status->status.ok()) { + status->status = InvalidArgument( + "invalid string tensor encoding (string #", i, " of ", + srcarray.size(), "): ", status->status.error_message()); + return nullptr; + } dst += consumed; dst_len -= consumed; } - CHECK_EQ(dst, base + size); + if (dst != base + size) { + status->status = InvalidArgument( + "invalid string tensor encoding (decoded ", (dst - base), + " bytes, but the tensor is encoded in ", size, " bytes"); + return nullptr; + } auto dims = src.shape().dim_sizes(); std::vector<tensorflow::int64> dimvec(dims.size()); @@ -650,7 +664,8 @@ static void TF_Run_Helper( static_cast<TF_DataType>(src.dtype()), src.shape()); continue; } - c_outputs[i] = TF_TensorFromTensor(src); + c_outputs[i] = TF_TensorFromTensor(src, status); + if (!status->status.ok()) return; } } @@ -1605,7 +1620,7 @@ void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name, Tensor t; status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t); if (!status->status.ok()) return; - *value = TF_TensorFromTensor(t); + *value = TF_TensorFromTensor(t, status); } void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, @@ -1616,7 +1631,7 @@ void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name, if (!status->status.ok()) return; const auto len = std::min(max_values, static_cast<int>(ts.size())); for (int i = 0; i < len; ++i) { - values[i] = TF_TensorFromTensor(ts[i]); + values[i] = TF_TensorFromTensor(ts[i], status); } } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 25b6cbd8e7a..1d191fc36d4 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -45,7 +45,7 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -TF_Tensor* TF_TensorFromTensor(const Tensor& src); +TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { @@ -137,6 +137,7 @@ TEST(CAPI, LibraryLoadFunctions) { void TestEncodeDecode(int line, const std::vector<string>& data) { const tensorflow::int64 n = data.size(); + TF_Status* status = TF_NewStatus(); for (const std::vector<tensorflow::int64>& dims : std::vector<std::vector<tensorflow::int64>>{ {n}, {1, n}, {n, 1}, {n / 2, 2}}) { @@ -145,7 +146,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) { for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { src.flat<string>()(i) = data[i]; } - TF_Tensor* dst = TF_TensorFromTensor(src); + TF_Tensor* dst = TF_TensorFromTensor(src, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Convert back to a C++ Tensor and ensure we get expected output. Tensor output; @@ -157,6 +159,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) { TF_DeleteTensor(dst); } + TF_DeleteStatus(status); } TEST(CAPI, TensorEncodeDecodeStrings) { @@ -914,7 +917,8 @@ TEST(CAPI, SavedModel) { TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); - csession.SetInputs({{input_op, TF_TensorFromTensor(input)}}); + csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}}); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = tensorflow::ParseTensorName(output_name).first.ToString();