Improve error messages on incompatible broadcastable dimensions

Users are often confused by the original error message. This change makes
the problem easier to debug.

Related GitHub issue: #46686

PiperOrigin-RevId: 353934118
Change-Id: I147c7f76fe1cc3c6b1a77ae77ba0e379db150948
This commit is contained in:
Jaesung Chung 2021-01-26 13:15:09 -08:00 committed by TensorFlower Gardener
parent 75c3c9e16e
commit 9ee7896d22
2 changed files with 50 additions and 6 deletions

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <complex>
#include <limits>
#include <memory>
#include <string>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
@ -382,6 +383,18 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
return TfLiteIntArrayEqual(input1->dims, input2->dims);
}
// Returns a human-readable rendering of `shape` for use in error messages.
// The dimensions are listed in order, separated by ", " and wrapped in square
// brackets, e.g. "[1, 2]". A shape with no dimensions is rendered as "[]".
// `shape` is dereferenced unconditionally and must not be null.
std::string GetShapeDebugString(const TfLiteIntArray* shape) {
  // Open the bracket up front so a zero-dimensional shape still yields a
  // balanced "[]" rather than a lone "]".
  std::string str = "[";
  for (int d = 0; d < shape->size; ++d) {
    // The separator precedes every dimension except the first.
    if (d > 0) str += ", ";
    str += std::to_string(shape->data[d]);
  }
  str += "]";
  return str;
}
// TODO(petewarden): Having macros around this is ugly, look at other strategies
// before replicating this approach elsewhere.
#ifndef TF_LITE_STATIC_MEMORY
@ -401,7 +414,13 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
for (int i = 0; i < out_dims; ++i) {
int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1);
if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
context->ReportError(context,
"Given shapes, %s and %s, are not broadcastable.",
GetShapeDebugString(input1->dims).c_str(),
GetShapeDebugString(input2->dims).c_str());
return kTfLiteError;
}
shape->data[out_dims - i - 1] = std::max(d1, d2);
}
*output_shape = shape.release();
@ -424,9 +443,15 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
int max_value = std::max(std::max(d1, d2), d3);
TF_LITE_ENSURE(context, d1 == 1 || d1 == max_value);
TF_LITE_ENSURE(context, d2 == 1 || d2 == max_value);
TF_LITE_ENSURE(context, d3 == 1 || d3 == max_value);
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
!(d3 == 1 || d3 == max_value)) {
context->ReportError(
context, "Given shapes, %s, %s and %s, are not broadcastable.",
GetShapeDebugString(input1->dims).c_str(),
GetShapeDebugString(input2->dims).c_str(),
GetShapeDebugString(input3->dims).c_str());
return kTfLiteError;
}
shape->data[out_dims - i - 1] = max_value;
}
*output_shape = shape.release();

View File

@ -31,7 +31,22 @@ limitations under the License.
namespace tflite {
namespace {
// No-op error reporter: accepts a context and a printf-style format with
// variadic arguments, and discards all of them.
void ReportError(TfLiteContext* context, const char* format, ...) {}
// A TfLiteContext extended with storage for an error message, so that tests
// can inspect the text produced by the error-reporting callback.
struct TestContext : TfLiteContext {
  // Most recent formatted error message recorded for this context.
  string error;
};
// Error reporter for tests: expands the printf-style `format` and its
// variadic arguments into a fixed-size stack buffer and stores the result in
// the `error` field of the context.
//
// `context` is downcast with static_cast, so it must actually point to a
// TestContext. Messages that do not fit in the buffer are truncated by
// vsnprintf.
void ReportError(TfLiteContext* context, const char* format, ...) {
  constexpr size_t kMaxMessageSize = 1024;
  char message[kMaxMessageSize];

  va_list varargs;
  va_start(varargs, format);
  vsnprintf(message, sizeof(message), format, varargs);
  va_end(varargs);

  static_cast<TestContext*>(context)->error = message;
}
class KernelUtilTest : public ::testing::Test {
public:
@ -73,7 +88,7 @@ class KernelUtilTest : public ::testing::Test {
}
protected:
TfLiteContext context_;
TestContext context_;
TfLiteTensor tensor1_;
TfLiteTensor tensor2_;
TfLiteTensor tensor3_;
@ -108,6 +123,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDim) {
EXPECT_NE(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_,
&tensor2_, &output));
EXPECT_EQ(output, nullptr);
EXPECT_EQ(context_.error,
"Given shapes, [1, 2] and [1, 3], are not broadcastable.");
}
TEST_F(KernelUtilTest, BroadcastShapeOnes) {
@ -168,6 +185,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDimOnThreeTensors) {
CalculateShapeForBroadcast(&context_, &tensor1_, &tensor2_,
&tensor3_, &output));
EXPECT_EQ(output, nullptr);
EXPECT_EQ(context_.error,
"Given shapes, [1, 2], [1, 3] and [1, 4], are not broadcastable.");
}
TEST_F(KernelUtilTest, BroadcastShapeOnesOnThreeTensors) {