Support negative axis for reverse_v2.
PiperOrigin-RevId: 301736084 Change-Id: Id23154b459566d5f85ef0c8d8bc98117c510b963
This commit is contained in:
parent
ffbdfbbe6a
commit
6eca562622
@ -68,8 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
|
||||
const TfLiteTensor* axis_tensor = GetInput(context, node, kAxisTensor);
|
||||
int axis = GetTensorData<int32_t>(axis_tensor)[0];
|
||||
const int rank = NumDimensions(input);
|
||||
if (axis < 0) {
|
||||
axis += rank;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(context, axis >= 0 && axis < NumDimensions(input));
|
||||
TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
switch (output->type) {
|
||||
|
@ -30,7 +30,7 @@ def make_reverse_v2_tests(options):
|
||||
test_parameters = [{
|
||||
"dtype": [tf.float32, tf.bool],
|
||||
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
|
||||
"axis": [0, 1, 2, 3],
|
||||
"axis": [-2, -1, 0, 1, 2, 3],
|
||||
}]
|
||||
|
||||
def get_valid_axis(parameters):
|
||||
|
Loading…
x
Reference in New Issue
Block a user