Support negative axis for reverse_v2.

PiperOrigin-RevId: 301736084
Change-Id: Id23154b459566d5f85ef0c8d8bc98117c510b963
This commit is contained in:
Renjie Liu 2020-03-18 21:48:13 -07:00 committed by TensorFlower Gardener
parent ffbdfbbe6a
commit 6eca562622
2 changed files with 6 additions and 2 deletions

View File

@ -68,8 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
int axis = GetTensorData<int32_t>(axis_tensor)[0];
const int rank = NumDimensions(input);
if (axis < 0) {
axis += rank;
}
TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (output->type) {

View File

@ -30,7 +30,7 @@ def make_reverse_v2_tests(options):
test_parameters = [{
"dtype": [tf.float32, tf.bool],
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
"axis": [0, 1, 2, 3],
"axis": [-2, -1, 0, 1, 2, 3],
}]
def get_valid_axis(parameters):