Handle edge case in strided slice: begin equal to dim

If an input dimension has size N, a begin value equal to N should be accepted.
This change also fixes the reduce op kernels to handle inputs whose shape
contains a zero-sized dimension.

PiperOrigin-RevId: 333054231
Change-Id: Ibfc9728db94bd2c16d2033522df1ded24e0eded4
This commit is contained in:
Jaesung Chung 2020-09-22 04:55:03 -07:00 committed by TensorFlower Gardener
parent d7b9b383e2
commit fb578a809c
7 changed files with 168 additions and 16 deletions

View File

@ -132,6 +132,11 @@ inline bool ReduceGeneric(const T* input_data, const int* input_dims,
bool keep_dims, int* temp_index, int* resolved_axis,
T init_value,
T reducer(const T current, const T in)) {
// Return early when input shape has zero dim.
for (int i = 0; i < input_num_dims; ++i) {
if (input_dims[i] == 0) return true;
}
// Reset output data.
if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
output_data)) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
@ -69,8 +70,8 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
}
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size - 1] that can be used to index
// directly into the data.
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
// that can be used to index directly into the data.
inline int StartForAxis(const tflite::StridedSliceParams& params,
const RuntimeShape& input_shape, int axis) {
const auto begin_mask = params.begin_mask;
@ -102,7 +103,13 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
}
// Clamping
start = Clamp(start, 0, axis_size - 1);
if (strides[axis] > 0) {
// Forward iteration
start = Clamp(start, 0, axis_size);
} else {
// Backward iteration
start = Clamp(start, -1, axis_size - 1);
}
return start;
}

View File

@ -295,6 +295,11 @@ TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
op_params.axis_count = num_axis;
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
const TfLiteTensor* input = op_context.input;
// Return early when input shape has zero dim.
for (int i = 0; i < input->dims->size; ++i) {
if (input->dims->data[i] == 0) return kTfLiteOk;
}
// TODO(b/139102329): Handle all the cases in the combined reference
// method.
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
@ -371,14 +376,19 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
}
// Return early when input shape has zero dim.
const TfLiteTensor* input = op_context.input;
for (int i = 0; i < input->dims->size; ++i) {
if (input->dims->data[i] == 0) return kTfLiteOk;
}
if (kernel_type == kGenericOptimized) {
// Use optimized ops if available.
switch (op_context.input->type) {
switch (input->type) {
case kTfLiteInt8: {
tflite::MeanParams op_params;
op_params.axis_count = num_axis;
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
const TfLiteTensor* input = op_context.input;
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
@ -398,7 +408,6 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
tflite::MeanParams op_params;
op_params.axis_count = num_axis;
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
const TfLiteTensor* input = op_context.input;
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
op_params.axis_count == 2 &&
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
@ -519,22 +528,28 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
ResizeTempAxis(context, op_context, resolved_axis));
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
}
if (op_context->input->type == kTfLiteUInt8 ||
op_context->input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
const TfLiteTensor* input = op_context->input;
// Return early when input shape has zero dim.
for (int i = 0; i < input->dims->size; ++i) {
if (input->dims->data[i] == 0) return kTfLiteOk;
}
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, input->params.scale,
op_context->output->params.scale);
TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
op_context->output->params.zero_point);
}
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<T>(
GetTensorData<T>(op_context->input), op_context->input->dims->data,
op_context->input->dims->size, GetTensorData<T>(op_context->output),
op_context->output->dims->data, op_context->output->dims->size,
GetTensorData<int>(op_context->axis), num_axis,
op_context->params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis), init_value, reducer));
GetTensorData<T>(input), input->dims->data, input->dims->size,
GetTensorData<T>(op_context->output), op_context->output->dims->data,
op_context->output->dims->size, GetTensorData<int>(op_context->axis),
num_axis, op_context->params->keep_dims,
GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
init_value, reducer));
return kTfLiteOk;
}
@ -658,6 +673,11 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
}
// Return early when input shape has zero dim.
for (int i = 0; i < input->dims->size; ++i) {
if (input->dims->data[i] == 0) return kTfLiteOk;
}
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE(
context,

View File

@ -267,6 +267,17 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
}
// Mean over a tensor containing a zero-sized dimension must not crash; the
// reduced axes are kept as size-1 dims (keep_dims=true) in the output shape.
TEST(ConstFloatMeanOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  MeanOpConstModel model({TensorType_FLOAT32, {4, 0, 2}},
                         {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
// Uses a set of reduction conditions that trigger the specialized 4D version
// of Mean.
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
@ -663,6 +674,16 @@ TEST(ConstFloatSumOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({84, 100, 116})));
}
// Sum over a tensor containing a zero-sized dimension must not crash; the
// reduced axes stay as size-1 dims (keep_dims=true) in the output shape.
TEST(ConstFloatSumOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  SumOpConstModel model({TensorType_FLOAT32, {4, 0, 2}},
                        {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
TEST(DynamicFloatSumOpTest, NotKeepDims) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@ -842,6 +863,16 @@ TEST(ConstFloatProdOpTest, KeepDims) {
ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
}
// Prod over a tensor containing a zero-sized dimension must not crash; the
// reduced axes stay as size-1 dims (keep_dims=true) in the output shape.
TEST(ConstFloatProdOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  ProdOpConstModel model({TensorType_FLOAT32, {4, 0, 2}},
                         {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
TEST(DynamicFloatProdOpTest, NotKeepDims) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@ -915,6 +946,16 @@ TEST(ConstFloatMaxOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({20, 22, 24})));
}
// ReduceMax over a tensor containing a zero-sized dimension must not crash;
// the reduced axes stay as size-1 dims (keep_dims=true) in the output shape.
TEST(ConstFloatMaxOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  MaxOpConstModel model({TensorType_FLOAT32, {4, 0, 2}},
                        {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
TEST(DynamicFloatMaxOpTest, NotKeepDims) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@ -1128,6 +1169,16 @@ TEST(ConstFloatMinOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({1, 3, 5})));
}
// ReduceMin over a tensor containing a zero-sized dimension must not crash;
// the reduced axes stay as size-1 dims (keep_dims=true) in the output shape.
TEST(ConstFloatMinOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  MinOpConstModel model({TensorType_FLOAT32, {4, 0, 2}},
                        {TensorType_FLOAT32, {3}}, {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
TEST(DynamicFloatMinOpTest, NotKeepDims) {
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@ -1338,6 +1389,16 @@ TEST(ConstAnyOpTest, KeepDims) {
EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
}
// ReduceAny over a bool tensor containing a zero-sized dimension must not
// crash; the reduced axes stay as size-1 dims (keep_dims=true) in the output.
TEST(ConstAnyOpTest, ZeroInputDim) {
  // Zero-sized dimensions are not supported when NNAPI use is forced.
  if (SingleOpModel::GetForceUseNnapi()) return;
  AnyOpConstModel model({TensorType_BOOL, {2, 0, 2}}, {TensorType_BOOL, {3}},
                        {2}, {0, 2}, true);
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 0, 1}));
}
TEST(DynamicAnyOpTest, NotKeepDims) {
std::vector<bool> data = {false, false, false, false, false, false,
false, true, false, false, false, true};

View File

@ -672,5 +672,25 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
// begin == dim size (1) on axis 0 with stride 1 and end 0 yields an empty
// slice; the begin value equal to the dimension must be accepted, not
// rejected, and the result has a zero-sized leading dimension.
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
  StridedSliceOpModel<TypeParam> model({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0,
                                       0);
  model.SetInput({1, 2});
  model.SetBegin({1});
  model.SetEnd({0});
  model.SetStrides({1});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
// Same empty-slice case with full 3-element begin/end/strides and
// begin_mask=6, end_mask=7: axis 0 uses begin 1 (== dim size) and end 0,
// producing a zero-sized leading dimension.
TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
  StridedSliceOpModel<TypeParam> model({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0,
                                       0);
  model.SetInput({1, 2});
  model.SetBegin({1, 0, 0});
  model.SetEnd({0, -1, -1});
  model.SetStrides({1, 1, 1});
  model.Invoke();
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
} // namespace
} // namespace tflite

View File

@ -160,6 +160,14 @@ def make_reduce_tests(reduce_op,
"keepdims": [True, False],
"fully_quantize": [True],
},
{
"input_dtype": [tf.float32],
"input_shape": [[2, 0, 2]],
"axis": [0],
"const_axis": [True],
"keepdims": [True, False],
"fully_quantize": [False],
},
]
# test_parameters include fully_quantize option only when
# allow_fully_quantize is True.

View File

@ -201,6 +201,37 @@ def make_strided_slice_tests(options):
"fully_quantize": [True],
},
]
if options.use_experimental_converter:
test_parameters = test_parameters + [
# Begin equal to input dim.
{
"dtype": [tf.float32],
"index_type": [tf.int32],
"input_shape": [[1, 1, 2]],
"begin": [[1]],
"end": [[0]],
"strides": [[1]],
"begin_mask": [0],
"end_mask": [1],
"shrink_axis_mask": [0],
"constant_indices": [True, False],
"fully_quantize": [False],
},
{
"dtype": [tf.float32],
"index_type": [tf.int32],
"input_shape": [[1, 1, 2]],
"begin": [[1, 0, 0]],
"end": [[0, -1, -1]],
"strides": [[1, 1, 1]],
"begin_mask": [6],
"end_mask": [7],
"shrink_axis_mask": [0],
"constant_indices": [True, False],
"fully_quantize": [False],
}
]
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)