Port TransposeConv/float to cpu_backend_gemm.

PiperOrigin-RevId: 247062528
This commit is contained in:
Benoit Jacob 2019-05-07 11:54:03 -07:00 committed by TensorFlower Gardener
parent 3a459ac6c5
commit 581577abf5
3 changed files with 62 additions and 53 deletions

View File

@ -527,6 +527,18 @@ void AddBiasAndEvalActivationFunction(const float* bias_data,
output_activation_max); output_activation_max);
} }
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
Eigen::MatrixBase<Result>* result) {
if (rhs.cols() == 1) {
gemmlowp::ScopedProfilingLabel label("GEMV");
result->col(0).noalias() = lhs * rhs.col(0);
} else {
gemmlowp::ScopedProfilingLabel label("GEMM");
result->noalias() = lhs * rhs;
}
}
inline void FullyConnected( inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape, const FullyConnectedParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& weights_shape, const float* input_data, const RuntimeShape& weights_shape,
@ -2087,6 +2099,28 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
filter_offset, input_offset, output_pipeline); filter_offset, input_offset, output_pipeline);
} }
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv");
// Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data);
TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
output_shape, im2col_data);
const auto im2col_matrix_map =
MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
const auto filter_matrix_map =
MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims, const float* filter_data, const Dims<4>& filter_dims,
int stride_width, int stride_height, int pad_width, int stride_width, int stride_height, int pad_width,

View File

@ -269,18 +269,6 @@ inline void AddBiasAndEvalActivationFunction(float output_activation_min,
#endif #endif
} }
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
Eigen::MatrixBase<Result>* result) {
if (rhs.cols() == 1) {
gemmlowp::ScopedProfilingLabel label("GEMV");
result->col(0).noalias() = lhs * rhs.col(0);
} else {
gemmlowp::ScopedProfilingLabel label("GEMM");
result->noalias() = lhs * rhs;
}
}
#ifdef GEMMLOWP_NEON #ifdef GEMMLOWP_NEON
// In the common case of batch size 1, a fully-connected node degenerates // In the common case of batch size 1, a fully-connected node degenerates
// to a matrix*vector product. LSTM cells contain a fully-connected node; // to a matrix*vector product. LSTM cells contain a fully-connected node;
@ -6301,7 +6289,8 @@ inline void TransposeConvV2(
const ConvParams& params, const RuntimeShape& input_shape, const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape, const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape, const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& col2im_shape, float* col2im_data) { float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
CpuBackendContext* cpu_backend_context) {
gemmlowp::ScopedProfilingLabel label("TransposeConvV2"); gemmlowp::ScopedProfilingLabel label("TransposeConvV2");
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
@ -6334,21 +6323,25 @@ inline void TransposeConvV2(
const int hwoi_ordered_filter_total_size = const int hwoi_ordered_filter_total_size =
filter_height * filter_width * output_depth; filter_height * filter_width * output_depth;
typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> cpu_backend_gemm::MatrixParams<float> lhs_params;
Matrix; lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
typedef Eigen::Map<Matrix> MatrixRef; lhs_params.rows = hwoi_ordered_filter_total_size;
typedef Eigen::Map<const Matrix> ConstMatrixRef; lhs_params.cols = input_depth;
ConstMatrixRef hwoi_ordered_filter_matrix_map(
hwoi_ordered_filter_data, hwoi_ordered_filter_total_size, input_depth);
float* output_data_p = output_data; float* output_data_p = output_data;
tensor_utils::ZeroVector(output_data, output_offset * batch_size); tensor_utils::ZeroVector(output_data, output_offset * batch_size);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
ConstMatrixRef input_matrix_map(input_data + input_offset * i, cpu_backend_gemm::MatrixParams<float> rhs_params;
input_image_size, input_depth); rhs_params.order = cpu_backend_gemm::Order::kColMajor;
MatrixRef output_matrix_map(col2im_data, input_image_size, rhs_params.rows = input_depth;
hwoi_ordered_filter_total_size); rhs_params.cols = input_image_size;
Gemm(input_matrix_map, hwoi_ordered_filter_matrix_map.transpose(), cpu_backend_gemm::MatrixParams<float> dst_params;
&output_matrix_map); dst_params.order = cpu_backend_gemm::Order::kColMajor;
dst_params.rows = hwoi_ordered_filter_total_size;
dst_params.cols = input_image_size;
cpu_backend_gemm::GemmParams<float, float> gemm_params;
cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
input_data + input_offset * i, dst_params,
col2im_data, gemm_params, cpu_backend_context);
Col2im(col2im_data, output_depth, output_height, output_width, Col2im(col2im_data, output_depth, output_height, output_width,
filter_height, filter_width, padding_top, padding_left, filter_height, filter_width, padding_top, padding_left,
@ -6358,29 +6351,6 @@ inline void TransposeConvV2(
} }
} }
// TODO(renjieliu): Investigate whether we need to keep this.
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& output_shape,
float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv");
// Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data);
TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
output_shape, im2col_data);
const auto im2col_matrix_map =
MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
const auto filter_matrix_map =
MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
// Integer-only version of ResizeNearestNeighbor. Since scales are represented // Integer-only version of ResizeNearestNeighbor. Since scales are represented
// in fixed-point and thus approximated, |in_x| or |in_y| may differ from the // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
// reference version. Debug checks are in place to test if this occurs. // reference version. Debug checks are in place to test if this occurs.

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h" #include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"
#include "tensorflow/lite/kernels/eigen_support.h" #include "tensorflow/lite/kernels/eigen_support.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
@ -85,11 +86,13 @@ struct OpData {
void* Init(TfLiteContext* context, const char* buffer, size_t length) { void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData; auto* data = new OpData;
eigen_support::IncrementUsageCounter(context); eigen_support::IncrementUsageCounter(context);
cpu_backend_support::IncrementUsageCounter(context);
return data; return data;
} }
void Free(TfLiteContext* context, void* buffer) { void Free(TfLiteContext* context, void* buffer) {
eigen_support::DecrementUsageCounter(context); eigen_support::DecrementUsageCounter(context);
cpu_backend_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer); delete reinterpret_cast<OpData*>(buffer);
} }
@ -306,8 +309,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
template <KernelType kernel_type> template <KernelType kernel_type>
void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data, void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
const TfLiteTensor* input, const TfLiteTensor* weights, const OpData* data, const TfLiteTensor* input,
const TfLiteTensor* weights,
const TfLiteTensor* transposed_weights, TfLiteTensor* col2im, const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
TfLiteTensor* output) { TfLiteTensor* output) {
tflite::ConvParams op_params; tflite::ConvParams op_params;
@ -333,7 +337,8 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
GetTensorShape(transposed_weights), GetTensorShape(transposed_weights),
GetTensorData<float>(transposed_weights), GetTensorShape(output), GetTensorData<float>(transposed_weights), GetTensorShape(output),
GetTensorData<float>(output), GetTensorShape(col2im), GetTensorData<float>(output), GetTensorShape(col2im),
GetTensorData<float>(col2im)); GetTensorData<float>(col2im),
cpu_backend_support::GetFromContext(context));
break; break;
} }
} }
@ -419,8 +424,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ResizeAndTransposeWeights(context, weights, transposed_weights); ResizeAndTransposeWeights(context, weights, transposed_weights);
} }
} }
EvalFloat<kernel_type>(params, data, input, weights, transposed_weights, EvalFloat<kernel_type>(context, params, data, input, weights,
col2im, output); transposed_weights, col2im, output);
break; break;
} }
case kTfLiteUInt8: { case kTfLiteUInt8: {