STT-tensorflow/tensorflow/lite/kernels/internal/transpose_utils.h
Jaesung Chung 37248520e4 Implement general dimension reduction optimizations
PiperOrigin-RevId: 276480378
Change-Id: Iaf8f9228c51cf4f39c95c7f66d02abea2ad2b9fa
2019-10-24 07:09:09 -07:00

53 lines
2.3 KiB
C++

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace transpose_utils {
// IsTranspose2DApplicable returns true if the given perm can be lowered to a
// 2D transpose op. If possible, it copies the lowered dimension counts to the
// given dim0 and dim1 pointers.
bool IsTranspose2DApplicable(const TransposeParams& params,
const RuntimeShape& input_shape, int* dim0,
int* dim1);
// RemoveOneSizeDimensions removes one size dimensions in the given input/output
// shapes and adjusts the parameter values for transpose op.
void RemoveOneSizeDimensions(RuntimeShape* input_shape,
RuntimeShape* output_shape,
TransposeParams* params);
// Flatten finds the dimensions that can be flatten, shrinks the given shapes
// and the given perm parameter to reflect the non-flatten dimensions, and
// returns the total size of the non-flatten dimensions.
//
// E.g, in perm [0, 1, 3, 2] case, the first two dimensions can be flatten and
// it returns |Dim Size(2)| x |Dim Size(3)|.
size_t Flatten(const RuntimeShape& input_shape,
const RuntimeShape& output_shape, const TransposeParams& params,
RuntimeShape* non_flatten_input_shape,
RuntimeShape* non_flatten_output_shape,
TransposeParams* non_flatten_params);
} // namespace transpose_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TRANSPOSE_UTILS_H_