Port TransposeConv/float to cpu_backend_gemm.
PiperOrigin-RevId: 247062528
commit 581577abf5
parent 3a459ac6c5
@@ -527,6 +527,18 @@ void AddBiasAndEvalActivationFunction(const float* bias_data,
                                     output_activation_max);
 }
 
+template <typename Lhs, typename Rhs, typename Result>
+void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
+          Eigen::MatrixBase<Result>* result) {
+  if (rhs.cols() == 1) {
+    gemmlowp::ScopedProfilingLabel label("GEMV");
+    result->col(0).noalias() = lhs * rhs.col(0);
+  } else {
+    gemmlowp::ScopedProfilingLabel label("GEMM");
+    result->noalias() = lhs * rhs;
+  }
+}
+
 inline void FullyConnected(
     const FullyConnectedParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& weights_shape,
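Note: the Gemm helper added above (moved into this legacy header by this commit) special-cases a single-column right-hand side so that matrix*vector products take the cheaper GEMV path and get a distinct profiling label. A minimal usage sketch with invented shapes; it assumes Eigen is already in scope, as it is in this header:

// Sketch only: exercises both branches of the Gemm helper above.
// Shapes are invented for illustration.
void GemmHelperSketch() {
  Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(4, 3);
  Eigen::MatrixXf vec = Eigen::MatrixXf::Random(3, 1);  // one column -> GEMV branch
  Eigen::MatrixXf mat = Eigen::MatrixXf::Random(3, 2);  // two columns -> GEMM branch
  Eigen::MatrixXf out_vec(4, 1);
  Eigen::MatrixXf out_mat(4, 2);
  Gemm(lhs, vec, &out_vec);  // takes the result->col(0).noalias() path
  Gemm(lhs, mat, &out_mat);  // takes the result->noalias() path
}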
@@ -2087,6 +2099,28 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
                      filter_offset, input_offset, output_pipeline);
 }
 
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConv");
+  // Note we could use transposed weights with forward conv for unstrided
+  // cases. But we are already getting good performance with this code as-is.
+  TFLITE_DCHECK(im2col_data);
+  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+                  output_shape, im2col_data);
+
+  const auto im2col_matrix_map =
+      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
+  const auto filter_matrix_map =
+      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+  auto output_matrix_map =
+      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           int stride_width, int stride_height, int pad_width,
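Note: the overload added above reduces the whole transpose convolution to one GEMM. TransposeIm2col scatters input values into the im2col buffer so that, with the filter flattened to 2-D, output = filter^T * im2col. A hedged caller sketch follows; the concrete shapes, the OHWI filter layout, and the im2col depth of filter_h * filter_w * input_depth are illustrative assumptions, not taken from this diff:

// Hypothetical caller (assumed to live inside the same namespace as
// TransposeConv; shapes invented; im2col_scratch must hold the buffer).
void RunTransposeConvFloat(const float* input, const float* filter,
                           float* output, float* im2col_scratch) {
  ConvParams op_params = {};  // zero-init; only stride/padding are read here
  op_params.stride_width = 2;
  op_params.stride_height = 2;
  op_params.padding_values.width = 1;   // illustrative 'SAME'-style padding
  op_params.padding_values.height = 1;
  const RuntimeShape input_shape({1, 4, 4, 8});    // N x H x W x input_depth
  const RuntimeShape filter_shape({16, 3, 3, 8});  // assumed OHWI ordering
  const RuntimeShape output_shape({1, 8, 8, 16});  // stride-2 upsampling
  const RuntimeShape im2col_shape({1, 8, 8, 3 * 3 * 8});
  TransposeConv(op_params, input_shape, input, filter_shape, filter,
                output_shape, output, im2col_shape, im2col_scratch);
}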
@@ -269,18 +269,6 @@ inline void AddBiasAndEvalActivationFunction(float output_activation_min,
 #endif
 }
 
-template <typename Lhs, typename Rhs, typename Result>
-void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
-          Eigen::MatrixBase<Result>* result) {
-  if (rhs.cols() == 1) {
-    gemmlowp::ScopedProfilingLabel label("GEMV");
-    result->col(0).noalias() = lhs * rhs.col(0);
-  } else {
-    gemmlowp::ScopedProfilingLabel label("GEMM");
-    result->noalias() = lhs * rhs;
-  }
-}
-
 #ifdef GEMMLOWP_NEON
 // In the common case of batch size 1, a fully-connected node degenerates
 // to a matrix*vector product. LSTM cells contain a fully-connected node;
@@ -6301,7 +6289,8 @@ inline void TransposeConvV2(
     const ConvParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
     const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
-    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data) {
+    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
+    CpuBackendContext* cpu_backend_context) {
   gemmlowp::ScopedProfilingLabel label("TransposeConvV2");
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
@@ -6334,21 +6323,25 @@ inline void TransposeConvV2(
   const int hwoi_ordered_filter_total_size =
       filter_height * filter_width * output_depth;
 
-  typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
-      Matrix;
-  typedef Eigen::Map<Matrix> MatrixRef;
-  typedef Eigen::Map<const Matrix> ConstMatrixRef;
-  ConstMatrixRef hwoi_ordered_filter_matrix_map(
-      hwoi_ordered_filter_data, hwoi_ordered_filter_total_size, input_depth);
+  cpu_backend_gemm::MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = hwoi_ordered_filter_total_size;
+  lhs_params.cols = input_depth;
   float* output_data_p = output_data;
   tensor_utils::ZeroVector(output_data, output_offset * batch_size);
   for (int i = 0; i < batch_size; ++i) {
-    ConstMatrixRef input_matrix_map(input_data + input_offset * i,
-                                    input_image_size, input_depth);
-    MatrixRef output_matrix_map(col2im_data, input_image_size,
-                                hwoi_ordered_filter_total_size);
-    Gemm(input_matrix_map, hwoi_ordered_filter_matrix_map.transpose(),
-         &output_matrix_map);
+    cpu_backend_gemm::MatrixParams<float> rhs_params;
+    rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+    rhs_params.rows = input_depth;
+    rhs_params.cols = input_image_size;
+    cpu_backend_gemm::MatrixParams<float> dst_params;
+    dst_params.order = cpu_backend_gemm::Order::kColMajor;
+    dst_params.rows = hwoi_ordered_filter_total_size;
+    dst_params.cols = input_image_size;
+    cpu_backend_gemm::GemmParams<float, float> gemm_params;
+    cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
+                           input_data + input_offset * i, dst_params,
+                           col2im_data, gemm_params, cpu_backend_context);
 
     Col2im(col2im_data, output_depth, output_height, output_width,
            filter_height, filter_width, padding_top, padding_left,
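Note: the translation from the deleted Eigen code is mechanical but worth spelling out. The old row-major product col2im[input_image_size x filter_total] = input * filter^T is replaced by its transpose, dst[filter_total x input_image_size] = filter * input^T, and since a column-major M x N matrix occupies memory exactly like a row-major N x M one, col2im_data receives the same bytes as before, with no data movement for the operand maps either. A standalone sketch of the float cpu_backend_gemm call (sizes invented; the include paths are my assumption of where these types live):

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

// Sketch: computes dst = lhs * rhs on the float path.
void CpuBackendGemmSketch(const float* lhs_data, const float* rhs_data,
                          float* dst_data, tflite::CpuBackendContext* context) {
  namespace gemm = tflite::cpu_backend_gemm;
  gemm::MatrixParams<float> lhs_params;
  lhs_params.order = gemm::Order::kRowMajor;
  lhs_params.rows = 144;  // e.g. filter_height * filter_width * output_depth
  lhs_params.cols = 8;    // e.g. input_depth
  gemm::MatrixParams<float> rhs_params;
  rhs_params.order = gemm::Order::kColMajor;
  rhs_params.rows = 8;    // must match lhs_params.cols
  rhs_params.cols = 16;   // e.g. input_image_size
  gemm::MatrixParams<float> dst_params;
  dst_params.order = gemm::Order::kColMajor;
  dst_params.rows = 144;
  dst_params.cols = 16;
  gemm::GemmParams<float, float> gemm_params;  // defaults: no bias, open clamp range
  gemm::Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
             gemm_params, context);
}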
@@ -6358,29 +6351,6 @@ inline void TransposeConvV2(
   }
 }
 
-// TODO(renjieliu): Investigate whether we need to keep this.
-inline void TransposeConv(
-    const ConvParams& params, const RuntimeShape& input_shape,
-    const float* input_data, const RuntimeShape& filter_shape,
-    const float* filter_data, const RuntimeShape& output_shape,
-    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
-  gemmlowp::ScopedProfilingLabel label("TransposeConv");
-  // Note we could use transposed weights with forward conv for unstrided
-  // cases. But we are already getting good performance with this code as-is.
-  TFLITE_DCHECK(im2col_data);
-  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
-                  output_shape, im2col_data);
-
-  const auto im2col_matrix_map =
-      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
-  const auto filter_matrix_map =
-      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
-  auto output_matrix_map =
-      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
-
-  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
-}
-
 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
 // reference version. Debug checks are in place to test if this occurs.
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/cpu_backend_support.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
@@ -85,11 +86,13 @@ struct OpData {
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* data = new OpData;
   eigen_support::IncrementUsageCounter(context);
+  cpu_backend_support::IncrementUsageCounter(context);
   return data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
   eigen_support::DecrementUsageCounter(context);
+  cpu_backend_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
 
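Note: both support libraries are reference-counted per TfLiteContext, so the shared backend state stays alive while at least one op instance holds a count. The eigen_support counter is kept alongside the new one because, per this commit's title, only the float path moves to cpu_backend_gemm. A skeleton of the convention for a hypothetical kernel (only the cpu_backend_support calls come from this diff; everything else is invented for illustration):

namespace hypothetical_op {

struct OpData {};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  cpu_backend_support::IncrementUsageCounter(context);  // take a reference
  return new OpData;
}

void Free(TfLiteContext* context, void* buffer) {
  cpu_backend_support::DecrementUsageCounter(context);  // last user cleans up
  delete reinterpret_cast<OpData*>(buffer);
}

}  // namespace hypothetical_op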
@@ -306,8 +309,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 template <KernelType kernel_type>
-void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
-               const TfLiteTensor* input, const TfLiteTensor* weights,
+void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
+               const OpData* data, const TfLiteTensor* input,
+               const TfLiteTensor* weights,
                const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
                TfLiteTensor* output) {
   tflite::ConvParams op_params;
@@ -333,7 +337,8 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
           GetTensorShape(transposed_weights),
           GetTensorData<float>(transposed_weights), GetTensorShape(output),
           GetTensorData<float>(output), GetTensorShape(col2im),
-          GetTensorData<float>(col2im));
+          GetTensorData<float>(col2im),
+          cpu_backend_support::GetFromContext(context));
       break;
     }
   }
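Note: the one extra argument is how the shared context reaches the GEMM. cpu_backend_support::GetFromContext(context) hands back the per-interpreter CpuBackendContext (thread pool and backend state), which EvalFloat forwards into TransposeConvV2. A sketch of the retrieval step, assuming the usage counter was taken in Init as above:

// Inside any Eval that needs a GEMM:
CpuBackendContext* backend_context =
    cpu_backend_support::GetFromContext(context);
// ... forward backend_context to e.g. optimized_ops::TransposeConvV2().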
@@ -419,8 +424,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         ResizeAndTransposeWeights(context, weights, transposed_weights);
       }
     }
-    EvalFloat<kernel_type>(params, data, input, weights, transposed_weights,
-                           col2im, output);
+    EvalFloat<kernel_type>(context, params, data, input, weights,
+                           transposed_weights, col2im, output);
     break;
   }
   case kTfLiteUInt8: {