Improve error messages on incompatible broadcastable dimensions
Users are often confused with the original error message. This change makes debug the problem easily. Related github issue #46686 PiperOrigin-RevId: 353934118 Change-Id: I147c7f76fe1cc3c6b1a77ae77ba0e379db150948
This commit is contained in:
parent
75c3c9e16e
commit
9ee7896d22
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
@ -382,6 +383,18 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
|
||||
return TfLiteIntArrayEqual(input1->dims, input2->dims);
|
||||
}
|
||||
|
||||
std::string GetShapeDebugString(const TfLiteIntArray* shape) {
|
||||
std::string str;
|
||||
for (int d = 0; d < shape->size; ++d) {
|
||||
if (str.empty())
|
||||
str = "[" + std::to_string(shape->data[d]);
|
||||
else
|
||||
str += ", " + std::to_string(shape->data[d]);
|
||||
}
|
||||
str += "]";
|
||||
return str;
|
||||
}
|
||||
|
||||
// TODO(petewarden): Having macros around this is ugly, look at other strategies
|
||||
// before replicating this approach elsewhere.
|
||||
#ifndef TF_LITE_STATIC_MEMORY
|
||||
@ -401,7 +414,13 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
|
||||
for (int i = 0; i < out_dims; ++i) {
|
||||
int d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1);
|
||||
int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
|
||||
TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1);
|
||||
if (!(d1 == d2 || d1 == 1 || d2 == 1)) {
|
||||
context->ReportError(context,
|
||||
"Given shapes, %s and %s, are not broadcastable.",
|
||||
GetShapeDebugString(input1->dims).c_str(),
|
||||
GetShapeDebugString(input2->dims).c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
shape->data[out_dims - i - 1] = std::max(d1, d2);
|
||||
}
|
||||
*output_shape = shape.release();
|
||||
@ -424,9 +443,15 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
|
||||
int d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1);
|
||||
int d3 = i >= dims3 ? 1 : SizeOfDimension(input3, dims3 - i - 1);
|
||||
int max_value = std::max(std::max(d1, d2), d3);
|
||||
TF_LITE_ENSURE(context, d1 == 1 || d1 == max_value);
|
||||
TF_LITE_ENSURE(context, d2 == 1 || d2 == max_value);
|
||||
TF_LITE_ENSURE(context, d3 == 1 || d3 == max_value);
|
||||
if (!(d1 == 1 || d1 == max_value) || !(d2 == 1 || d2 == max_value) ||
|
||||
!(d3 == 1 || d3 == max_value)) {
|
||||
context->ReportError(
|
||||
context, "Given shapes, %s, %s and %s, are not broadcastable.",
|
||||
GetShapeDebugString(input1->dims).c_str(),
|
||||
GetShapeDebugString(input2->dims).c_str(),
|
||||
GetShapeDebugString(input3->dims).c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
shape->data[out_dims - i - 1] = max_value;
|
||||
}
|
||||
*output_shape = shape.release();
|
||||
|
@ -31,7 +31,22 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
void ReportError(TfLiteContext* context, const char* format, ...) {}
|
||||
struct TestContext : public TfLiteContext {
|
||||
string error;
|
||||
};
|
||||
|
||||
void ReportError(TfLiteContext* context, const char* format, ...) {
|
||||
TestContext* c = static_cast<TestContext*>(context);
|
||||
const size_t kBufferSize = 1024;
|
||||
char temp_buffer[kBufferSize];
|
||||
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
vsnprintf(temp_buffer, kBufferSize, format, args);
|
||||
va_end(args);
|
||||
|
||||
c->error = temp_buffer;
|
||||
}
|
||||
|
||||
class KernelUtilTest : public ::testing::Test {
|
||||
public:
|
||||
@ -73,7 +88,7 @@ class KernelUtilTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
protected:
|
||||
TfLiteContext context_;
|
||||
TestContext context_;
|
||||
TfLiteTensor tensor1_;
|
||||
TfLiteTensor tensor2_;
|
||||
TfLiteTensor tensor3_;
|
||||
@ -108,6 +123,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDim) {
|
||||
EXPECT_NE(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_,
|
||||
&tensor2_, &output));
|
||||
EXPECT_EQ(output, nullptr);
|
||||
EXPECT_EQ(context_.error,
|
||||
"Given shapes, [1, 2] and [1, 3], are not broadcastable.");
|
||||
}
|
||||
|
||||
TEST_F(KernelUtilTest, BroadcastShapeOnes) {
|
||||
@ -168,6 +185,8 @@ TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDimOnThreeTensors) {
|
||||
CalculateShapeForBroadcast(&context_, &tensor1_, &tensor2_,
|
||||
&tensor3_, &output));
|
||||
EXPECT_EQ(output, nullptr);
|
||||
EXPECT_EQ(context_.error,
|
||||
"Given shapes, [1, 2], [1, 3] and [1, 4], are not broadcastable.");
|
||||
}
|
||||
|
||||
TEST_F(KernelUtilTest, BroadcastShapeOnesOnThreeTensors) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user