diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc
index 88d91a85533..e38b02d447d 100644
--- a/tensorflow/lite/kernels/kernel_util.cc
+++ b/tensorflow/lite/kernels/kernel_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include <complex>
 #include <limits>
 #include <memory>
+#include <string>
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -382,6 +383,18 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
   return TfLiteIntArrayEqual(input1->dims, input2->dims);
 }
 
+std::string GetShapeDebugString(const TfLiteIntArray* shape) {
+  std::string str;
+  for (int d = 0; d < shape->size; ++d) {
+    if (str.empty())
+      str = "[" + std::to_string(shape->data[d]);
+    else
+      str += ", " + std::to_string(shape->data[d]);
+  }
+  str += "]";
+  return str;
+}
+
 // TODO(petewarden): Having macros around this is ugly, look at other strategies
 // before replicating this approach elsewhere.
 #ifndef TF_LITE_STATIC_MEMORY
@@ -401,7 +414,13 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
   for (int i = 0; i < out_dims; ++i) {
     int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
-    TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1);
+    if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
+      context->ReportError(context,
+                           "Given shapes, %s and %s, are not broadcastable.",
+                           GetShapeDebugString(input1->dims).c_str(),
+                           GetShapeDebugString(input2->dims).c_str());
+      return kTfLiteError;
+    }
     shape->data[out_dims - i - 1] = std::max(d1, d2);
   }
   *output_shape = shape.release();
@@ -424,9 +443,15 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
     int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
     int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
     int max_value = std::max(std::max(d1, d2), d3);
-    TF_LITE_ENSURE(context, d1 == 1 || d1 == max_value);
-    TF_LITE_ENSURE(context, d2 == 1 || d2 == max_value);
-    TF_LITE_ENSURE(context, d3 == 1 || d3 == max_value);
+    if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
+        !(d3 == 1 || d3 == max_value)) {
+      context->ReportError(
+          context, "Given shapes, %s, %s and %s, are not broadcastable.",
+          GetShapeDebugString(input1->dims).c_str(),
+          GetShapeDebugString(input2->dims).c_str(),
+          GetShapeDebugString(input3->dims).c_str());
+      return kTfLiteError;
+    }
     shape->data[out_dims - i - 1] = max_value;
   }
   *output_shape = shape.release();
diff --git a/tensorflow/lite/kernels/kernel_util_test.cc b/tensorflow/lite/kernels/kernel_util_test.cc
index db0cc3cb39c..67e6dce3e3b 100644
--- a/tensorflow/lite/kernels/kernel_util_test.cc
+++ b/tensorflow/lite/kernels/kernel_util_test.cc
@@ -31,7 +31,22 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-void ReportError(TfLiteContext* context, const char* format, ...) {}
+struct TestContext : public TfLiteContext {
+  string error;
+};
+
+void ReportError(TfLiteContext* context, const char* format, ...) {
+  TestContext* c = static_cast<TestContext*>(context);
+  const size_t kBufferSize = 1024;
+  char temp_buffer[kBufferSize];
+
+  va_list args;
+  va_start(args, format);
+  vsnprintf(temp_buffer, kBufferSize, format, args);
+  va_end(args);
+
+  c->error = temp_buffer;
+}
 
 class KernelUtilTest : public ::testing::Test {
  public:
@@ -73,7 +88,7 @@ class KernelUtilTest : public ::testing::Test {
   }
 
  protected:
-  TfLiteContext context_;
+  TestContext context_;
   TfLiteTensor tensor1_;
   TfLiteTensor tensor2_;
   TfLiteTensor tensor3_;
@@ -108,6 +123,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDim) {
   EXPECT_NE(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_,
                                                   &tensor2_, &output));
   EXPECT_EQ(output, nullptr);
+  EXPECT_EQ(context_.error,
+            "Given shapes, [1, 2] and [1, 3], are not broadcastable.");
 }
 
 TEST_F(KernelUtilTest, BroadcastShapeOnes) {
@@ -168,6 +185,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDimOnThreeTensors) {
             CalculateShapeForBroadcast(&context_, &tensor1_, &tensor2_,
                                        &tensor3_, &output));
   EXPECT_EQ(output, nullptr);
+  EXPECT_EQ(context_.error,
+            "Given shapes, [1, 2], [1, 3] and [1, 4], are not broadcastable.");
 }
 
 TEST_F(KernelUtilTest, BroadcastShapeOnesOnThreeTensors) {
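
For reference, here is a minimal standalone sketch (not part of the patch) of the two pieces the kernel_util.cc change combines: the "[d0, d1, ...]" shape formatting of the new GetShapeDebugString helper and the right-aligned, NumPy-style broadcast compatibility rule that CalculateShapeForBroadcast enforces. ShapeToString and Broadcastable are illustrative stand-ins using std::vector<int> instead of TfLiteIntArray, so the snippet compiles without any TFLite headers.

// Standalone sketch: mirrors the patch's shape formatting and the pairwise
// broadcastability check, with std::vector<int> standing in for TfLiteIntArray.
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Same formatting as GetShapeDebugString in the patch: "[d0, d1, ...]".
std::string ShapeToString(const std::vector<int>& shape) {
  std::string str;
  for (int d : shape) {
    if (str.empty())
      str = "[" + std::to_string(d);
    else
      str += ", " + std::to_string(d);
  }
  str += "]";
  return str;
}

// Right-aligned broadcast check: dimensions are compared from the innermost
// axis outward; a pair is compatible when the sizes are equal or one is 1.
bool Broadcastable(const std::vector<int>& a, const std::vector<int>& b) {
  const int out_dims = static_cast<int>(std::max(a.size(), b.size()));
  for (int i = 0; i < out_dims; ++i) {
    const int d1 = i >= static_cast<int>(a.size()) ? 1 : a[a.size() - i - 1];
    const int d2 = i >= static_cast<int>(b.size()) ? 1 : b[b.size() - i - 1];
    if (!(d1 == d2 || d1 == 1 || d2 == 1)) return false;
  }
  return true;
}

int main() {
  // Matches the strings asserted by the new kernel_util_test.cc cases.
  assert(ShapeToString({1, 2}) == "[1, 2]");
  assert(ShapeToString({1, 3}) == "[1, 3]");

  // {1, 2} vs {1, 3}: the innermost dims 2 and 3 differ and neither is 1, so
  // the patched code reports the "... are not broadcastable." error and fails.
  assert(!Broadcastable({1, 2}, {1, 3}));
  assert(Broadcastable({8, 1, 6, 1}, {7, 1, 5}));
  return 0;
}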
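The test-side change follows the same pattern: the TestContext override of ReportError formats the variadic arguments into a fixed-size buffer and stores the result so EXPECT_EQ can compare the message text. A standalone sketch of that capture idiom, outside the TfLiteContext plumbing, is below; ErrorSink and CaptureError are illustrative names, not TFLite APIs.

// Standalone sketch: printf-style message capture via va_list/vsnprintf,
// as used by the test's ReportError override.
#include <cassert>
#include <cstdarg>
#include <cstdio>
#include <string>

struct ErrorSink {
  std::string error;
};

// Formats like printf and stores the result in the sink, truncating at
// 1024 bytes just as the test's fixed-size buffer does.
void CaptureError(ErrorSink* sink, const char* format, ...) {
  constexpr size_t kBufferSize = 1024;
  char temp_buffer[kBufferSize];

  va_list args;
  va_start(args, format);
  vsnprintf(temp_buffer, kBufferSize, format, args);
  va_end(args);

  sink->error = temp_buffer;
}

int main() {
  ErrorSink sink;
  CaptureError(&sink, "Given shapes, %s and %s, are not broadcastable.",
               "[1, 2]", "[1, 3]");
  assert(sink.error ==
         "Given shapes, [1, 2] and [1, 3], are not broadcastable.");
  return 0;
}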