Safe dynamic reshape option (only copy when the new_shape is not a nullptr)

PiperOrigin-RevId: 288254073
Change-Id: I68725492420f9c3a344b32f1c21d4762b9b47382
This commit is contained in:
Renjie Liu 2020-01-06 00:30:44 -08:00 committed by TensorFlower Gardener
parent dec91cdf18
commit 5c5a144afc
2 changed files with 12 additions and 10 deletions

View File

@ -499,10 +499,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
auto params = safe_allocator.Allocate<TfLiteReshapeParams>();
if (const auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
auto* new_shape = schema_params->new_shape();
TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
sizeof(params->shape), new_shape, params->shape, error_reporter,
"reshape"));
params->num_dimensions = new_shape->size();
// TODO(b/147203660): We need to figure out when dynamic reshape
// (new_shape is a tensor) happens, why the option is not a nullptr.
// But nonethless, we should only copy when new_shape is not a nullptr.
if (new_shape) {
TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray(
sizeof(params->shape), new_shape, params->shape, error_reporter,
"reshape"));
params->num_dimensions = new_shape->size();
}
}
*builtin_data = reinterpret_cast<void*>(params.release());
break;

View File

@ -93,15 +93,12 @@ TEST_F(FlatbufferConversionsTest, ParseBadSqueeze) {
"Input array not provided for operation 'squeeze'"));
}
TEST_F(FlatbufferConversionsTest, ParseBadReshape) {
TEST_F(FlatbufferConversionsTest, ParseDynamicReshape) {
const Operator* op = BuildTestOperator(
BuiltinOptions_ReshapeOptions, CreateSqueezeOptions(builder_).Union());
BuiltinOptions_ReshapeOptions, CreateReshapeOptions(builder_).Union());
void* output_data = nullptr;
EXPECT_NE(kTfLiteOk, ParseOpData(op, BuiltinOperator_RESHAPE, &mock_reporter_,
EXPECT_EQ(kTfLiteOk, ParseOpData(op, BuiltinOperator_RESHAPE, &mock_reporter_,
&mock_allocator_, &output_data));
EXPECT_THAT(mock_reporter_.GetAsString(),
::testing::ContainsRegex(
"Input array not provided for operation 'reshape'"));
}
TEST_F(FlatbufferConversionsTest, TestParseOpDataConv) {