Handle edge case in strided slice: begin equal to dim
If input has dim N, begin value with value N should be accepted. Also this change fixes reduce op kernels to handle inputs whose shape has zero dim. PiperOrigin-RevId: 333054231 Change-Id: Ibfc9728db94bd2c16d2033522df1ded24e0eded4
This commit is contained in:
parent
d7b9b383e2
commit
fb578a809c
tensorflow/lite
kernels
testing/op_tests
@ -132,6 +132,11 @@ inline bool ReduceGeneric(const T* input_data, const int* input_dims,
|
||||
bool keep_dims, int* temp_index, int* resolved_axis,
|
||||
T init_value,
|
||||
T reducer(const T current, const T in)) {
|
||||
// Return early when input shape has zero dim.
|
||||
for (int i = 0; i < input_num_dims; ++i) {
|
||||
if (input_dims[i] == 0) return true;
|
||||
}
|
||||
|
||||
// Reset output data.
|
||||
if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
|
||||
output_data)) {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/kernels/internal/compatibility.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
|
||||
@ -69,8 +70,8 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
|
||||
}
|
||||
|
||||
// Return the index for the first element along that axis. This index will be a
|
||||
// positive integer between [0, axis_size - 1] that can be used to index
|
||||
// directly into the data.
|
||||
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
|
||||
// that can be used to index directly into the data.
|
||||
inline int StartForAxis(const tflite::StridedSliceParams& params,
|
||||
const RuntimeShape& input_shape, int axis) {
|
||||
const auto begin_mask = params.begin_mask;
|
||||
@ -102,7 +103,13 @@ inline int StartForAxis(const tflite::StridedSliceParams& params,
|
||||
}
|
||||
|
||||
// Clamping
|
||||
start = Clamp(start, 0, axis_size - 1);
|
||||
if (strides[axis] > 0) {
|
||||
// Forward iteration
|
||||
start = Clamp(start, 0, axis_size);
|
||||
} else {
|
||||
// Backward iteration
|
||||
start = Clamp(start, -1, axis_size - 1);
|
||||
}
|
||||
|
||||
return start;
|
||||
}
|
||||
|
@ -295,6 +295,11 @@ TfLiteStatus EvalMeanReferenceOps(TfLiteContext* context,
|
||||
op_params.axis_count = num_axis;
|
||||
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
|
||||
const TfLiteTensor* input = op_context.input;
|
||||
// Return early when input shape has zero dim.
|
||||
for (int i = 0; i < input->dims->size; ++i) {
|
||||
if (input->dims->data[i] == 0) return kTfLiteOk;
|
||||
}
|
||||
|
||||
// TODO(b/139102329): Handle all the cases in the combined reference
|
||||
// method.
|
||||
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
|
||||
@ -371,14 +376,19 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
|
||||
}
|
||||
|
||||
// Return early when input shape has zero dim.
|
||||
const TfLiteTensor* input = op_context.input;
|
||||
for (int i = 0; i < input->dims->size; ++i) {
|
||||
if (input->dims->data[i] == 0) return kTfLiteOk;
|
||||
}
|
||||
|
||||
if (kernel_type == kGenericOptimized) {
|
||||
// Use optimized ops if available.
|
||||
switch (op_context.input->type) {
|
||||
switch (input->type) {
|
||||
case kTfLiteInt8: {
|
||||
tflite::MeanParams op_params;
|
||||
op_params.axis_count = num_axis;
|
||||
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
|
||||
const TfLiteTensor* input = op_context.input;
|
||||
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
|
||||
op_params.axis_count == 2 &&
|
||||
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
|
||||
@ -398,7 +408,6 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
|
||||
tflite::MeanParams op_params;
|
||||
op_params.axis_count = num_axis;
|
||||
ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
|
||||
const TfLiteTensor* input = op_context.input;
|
||||
if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
|
||||
op_params.axis_count == 2 &&
|
||||
((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
|
||||
@ -519,22 +528,28 @@ TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
|
||||
ResizeTempAxis(context, op_context, resolved_axis));
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
|
||||
}
|
||||
if (op_context->input->type == kTfLiteUInt8 ||
|
||||
op_context->input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
|
||||
|
||||
const TfLiteTensor* input = op_context->input;
|
||||
// Return early when input shape has zero dim.
|
||||
for (int i = 0; i < input->dims->size; ++i) {
|
||||
if (input->dims->data[i] == 0) return kTfLiteOk;
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
|
||||
TF_LITE_ENSURE_EQ(context, input->params.scale,
|
||||
op_context->output->params.scale);
|
||||
TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
|
||||
TF_LITE_ENSURE_EQ(context, input->params.zero_point,
|
||||
op_context->output->params.zero_point);
|
||||
}
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
reference_ops::ReduceGeneric<T>(
|
||||
GetTensorData<T>(op_context->input), op_context->input->dims->data,
|
||||
op_context->input->dims->size, GetTensorData<T>(op_context->output),
|
||||
op_context->output->dims->data, op_context->output->dims->size,
|
||||
GetTensorData<int>(op_context->axis), num_axis,
|
||||
op_context->params->keep_dims, GetTensorData<int>(temp_index),
|
||||
GetTensorData<int>(resolved_axis), init_value, reducer));
|
||||
GetTensorData<T>(input), input->dims->data, input->dims->size,
|
||||
GetTensorData<T>(op_context->output), op_context->output->dims->data,
|
||||
op_context->output->dims->size, GetTensorData<int>(op_context->axis),
|
||||
num_axis, op_context->params->keep_dims,
|
||||
GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
|
||||
init_value, reducer));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -658,6 +673,11 @@ TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
|
||||
TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
|
||||
}
|
||||
// Return early when input shape has zero dim.
|
||||
for (int i = 0; i < input->dims->size; ++i) {
|
||||
if (input->dims->data[i] == 0) return kTfLiteOk;
|
||||
}
|
||||
|
||||
if (input->type == kTfLiteUInt8) {
|
||||
TF_LITE_ENSURE(
|
||||
context,
|
||||
|
@ -267,6 +267,17 @@ TEST(ConstFloatMeanOpTest, KeepDims) {
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5})));
|
||||
}
|
||||
|
||||
TEST(ConstFloatMeanOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
MeanOpConstModel m({TensorType_FLOAT32, {4, 0, 2}}, {TensorType_FLOAT32, {3}},
|
||||
{2}, {0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
// Uses a set of reduction conditions that trigger the specialized 4D version
|
||||
// of Mean.
|
||||
TEST(ConstFloatMeanOpTest, KeepDims4DMean) {
|
||||
@ -663,6 +674,16 @@ TEST(ConstFloatSumOpTest, KeepDims) {
|
||||
ElementsAreArray(ArrayFloatNear({84, 100, 116})));
|
||||
}
|
||||
|
||||
TEST(ConstFloatSumOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
SumOpConstModel m({TensorType_FLOAT32, {4, 0, 2}}, {TensorType_FLOAT32, {3}},
|
||||
{2}, {0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(DynamicFloatSumOpTest, NotKeepDims) {
|
||||
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
||||
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@ -842,6 +863,16 @@ TEST(ConstFloatProdOpTest, KeepDims) {
|
||||
ArrayFloatNear({7.74592e+06, 1.197504e+08, 6.6889152e+08})));
|
||||
}
|
||||
|
||||
TEST(ConstFloatProdOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
ProdOpConstModel m({TensorType_FLOAT32, {4, 0, 2}}, {TensorType_FLOAT32, {3}},
|
||||
{2}, {0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(DynamicFloatProdOpTest, NotKeepDims) {
|
||||
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
||||
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@ -915,6 +946,16 @@ TEST(ConstFloatMaxOpTest, KeepDims) {
|
||||
ElementsAreArray(ArrayFloatNear({20, 22, 24})));
|
||||
}
|
||||
|
||||
TEST(ConstFloatMaxOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
MaxOpConstModel m({TensorType_FLOAT32, {4, 0, 2}}, {TensorType_FLOAT32, {3}},
|
||||
{2}, {0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(DynamicFloatMaxOpTest, NotKeepDims) {
|
||||
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
||||
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@ -1128,6 +1169,16 @@ TEST(ConstFloatMinOpTest, KeepDims) {
|
||||
ElementsAreArray(ArrayFloatNear({1, 3, 5})));
|
||||
}
|
||||
|
||||
TEST(ConstFloatMinOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
MinOpConstModel m({TensorType_FLOAT32, {4, 0, 2}}, {TensorType_FLOAT32, {3}},
|
||||
{2}, {0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(DynamicFloatMinOpTest, NotKeepDims) {
|
||||
std::vector<float> data = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
||||
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@ -1338,6 +1389,16 @@ TEST(ConstAnyOpTest, KeepDims) {
|
||||
EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
|
||||
}
|
||||
|
||||
TEST(ConstAnyOpTest, ZeroInputDim) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
AnyOpConstModel m({TensorType_BOOL, {2, 0, 2}}, {TensorType_BOOL, {3}}, {2},
|
||||
{0, 2}, true);
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(DynamicAnyOpTest, NotKeepDims) {
|
||||
std::vector<bool> data = {false, false, false, false, false, false,
|
||||
false, true, false, false, false, true};
|
||||
|
@ -672,5 +672,25 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
|
||||
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
|
||||
m.SetInput({1, 2});
|
||||
m.SetBegin({1});
|
||||
m.SetEnd({0});
|
||||
m.SetStrides({1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
|
||||
}
|
||||
|
||||
TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
|
||||
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0, 0);
|
||||
m.SetInput({1, 2});
|
||||
m.SetBegin({1, 0, 0});
|
||||
m.SetEnd({0, -1, -1});
|
||||
m.SetStrides({1, 1, 1});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
@ -160,6 +160,14 @@ def make_reduce_tests(reduce_op,
|
||||
"keepdims": [True, False],
|
||||
"fully_quantize": [True],
|
||||
},
|
||||
{
|
||||
"input_dtype": [tf.float32],
|
||||
"input_shape": [[2, 0, 2]],
|
||||
"axis": [0],
|
||||
"const_axis": [True],
|
||||
"keepdims": [True, False],
|
||||
"fully_quantize": [False],
|
||||
},
|
||||
]
|
||||
# test_parameters include fully_quantize option only when
|
||||
# allow_fully_quantize is True.
|
||||
|
@ -201,6 +201,37 @@ def make_strided_slice_tests(options):
|
||||
"fully_quantize": [True],
|
||||
},
|
||||
]
|
||||
|
||||
if options.use_experimental_converter:
|
||||
test_parameters = test_parameters + [
|
||||
# Begin equal to input dim.
|
||||
{
|
||||
"dtype": [tf.float32],
|
||||
"index_type": [tf.int32],
|
||||
"input_shape": [[1, 1, 2]],
|
||||
"begin": [[1]],
|
||||
"end": [[0]],
|
||||
"strides": [[1]],
|
||||
"begin_mask": [0],
|
||||
"end_mask": [1],
|
||||
"shrink_axis_mask": [0],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
},
|
||||
{
|
||||
"dtype": [tf.float32],
|
||||
"index_type": [tf.int32],
|
||||
"input_shape": [[1, 1, 2]],
|
||||
"begin": [[1, 0, 0]],
|
||||
"end": [[0, -1, -1]],
|
||||
"strides": [[1, 1, 1]],
|
||||
"begin_mask": [6],
|
||||
"end_mask": [7],
|
||||
"shrink_axis_mask": [0],
|
||||
"constant_indices": [True, False],
|
||||
"fully_quantize": [False],
|
||||
}
|
||||
]
|
||||
_make_strided_slice_tests(options, test_parameters, expected_tf_failures=2)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user