Port TransposeConv/float to cpu_backend_gemm.

PiperOrigin-RevId: 247062528
Author:    Benoit Jacob (2019-05-07 11:54:03 -07:00)
Committer: TensorFlower Gardener
Commit:    581577abf5 (parent 3a459ac6c5)
3 changed files with 62 additions and 53 deletions

tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h

@@ -527,6 +527,18 @@ void AddBiasAndEvalActivationFunction(const float* bias_data,
                                        output_activation_max);
 }
 
+template <typename Lhs, typename Rhs, typename Result>
+void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
+          Eigen::MatrixBase<Result>* result) {
+  if (rhs.cols() == 1) {
+    gemmlowp::ScopedProfilingLabel label("GEMV");
+    result->col(0).noalias() = lhs * rhs.col(0);
+  } else {
+    gemmlowp::ScopedProfilingLabel label("GEMM");
+    result->noalias() = lhs * rhs;
+  }
+}
+
 inline void FullyConnected(
     const FullyConnectedParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& weights_shape,
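
Aside: the Eigen-based Gemm helper moved into this file dispatches on the number of right-hand-side columns, taking the cheaper matrix*vector path when there is only one. Below is a minimal, self-contained sketch of the same dispatch; the gemmlowp profiling labels are omitted and all sizes are invented for illustration.

#include <Eigen/Core>
#include <iostream>

// Simplified stand-in for the Gemm helper above (profiling labels omitted).
template <typename Lhs, typename Rhs, typename Result>
void GemmSketch(const Eigen::MatrixBase<Lhs>& lhs,
                const Eigen::MatrixBase<Rhs>& rhs,
                Eigen::MatrixBase<Result>* result) {
  if (rhs.cols() == 1) {
    // Single output column: a matrix*vector (GEMV) product.
    result->col(0).noalias() = lhs * rhs.col(0);
  } else {
    // General matrix*matrix (GEMM) product.
    result->noalias() = lhs * rhs;
  }
}

int main() {
  Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(4, 3);
  Eigen::MatrixXf rhs = Eigen::MatrixXf::Random(3, 1);  // one column -> GEMV path
  Eigen::MatrixXf result(4, 1);
  GemmSketch(lhs, rhs, &result);
  std::cout << result << "\n";
}

The .noalias() assignment promises Eigen that the destination does not overlap the operands, letting it skip the temporary it would otherwise allocate for the product.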
@@ -2087,6 +2099,28 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
                  filter_offset, input_offset, output_pipeline);
 }
 
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConv");
+  // Note we could use transposed weights with forward conv for unstrided
+  // cases. But we are already getting good performance with this code as-is.
+  TFLITE_DCHECK(im2col_data);
+  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+                  output_shape, im2col_data);
+
+  const auto im2col_matrix_map =
+      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
+  const auto filter_matrix_map =
+      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+  auto output_matrix_map =
+      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           int stride_width, int stride_height, int pad_width,
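
Aside: the Gemm call above evaluates the transpose convolution as output = filterᵀ × im2col(input). Assuming the mapping helpers do what their names suggest (last dimension mapped to rows, first dimension mapped to columns), the shapes line up as in this sketch; every dimension below is hypothetical.

#include <Eigen/Core>
#include <cassert>

int main() {
  // Hypothetical sizes for an OHWI-laid-out filter and its im2col buffer.
  const int output_depth = 8, filter_h = 3, filter_w = 3, input_depth = 4;
  const int num_output_pixels = 25;  // batch * output_h * output_w
  const int k = filter_h * filter_w * input_depth;

  // filter_matrix_map: k x output_depth; im2col_matrix_map: k x pixels.
  Eigen::MatrixXf filter = Eigen::MatrixXf::Random(k, output_depth);
  Eigen::MatrixXf im2col = Eigen::MatrixXf::Random(k, num_output_pixels);

  // Mirrors Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output):
  // each output pixel's channel vector is one column of the product.
  Eigen::MatrixXf output = filter.transpose() * im2col;
  assert(output.rows() == output_depth);
  assert(output.cols() == num_output_pixels);
}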

tensorflow/lite/kernels/internal/optimized/optimized_ops.h

@@ -269,18 +269,6 @@ inline void AddBiasAndEvalActivationFunction(float output_activation_min,
 #endif
 }
 
-template <typename Lhs, typename Rhs, typename Result>
-void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
-          Eigen::MatrixBase<Result>* result) {
-  if (rhs.cols() == 1) {
-    gemmlowp::ScopedProfilingLabel label("GEMV");
-    result->col(0).noalias() = lhs * rhs.col(0);
-  } else {
-    gemmlowp::ScopedProfilingLabel label("GEMM");
-    result->noalias() = lhs * rhs;
-  }
-}
-
 #ifdef GEMMLOWP_NEON
 // In the common case of batch size 1, a fully-connected node degenerates
 // to a matrix*vector product. LSTM cells contain a fully-connected node;
@@ -6301,7 +6289,8 @@ inline void TransposeConvV2(
     const ConvParams& params, const RuntimeShape& input_shape,
     const float* input_data, const RuntimeShape& hwoi_ordered_filter_shape,
    const float* hwoi_ordered_filter_data, const RuntimeShape& output_shape,
-    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data) {
+    float* output_data, const RuntimeShape& col2im_shape, float* col2im_data,
+    CpuBackendContext* cpu_backend_context) {
   gemmlowp::ScopedProfilingLabel label("TransposeConvV2");
   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_EQ(hwoi_ordered_filter_shape.DimensionsCount(), 4);
@@ -6334,21 +6323,25 @@ inline void TransposeConvV2(
   const int hwoi_ordered_filter_total_size =
       filter_height * filter_width * output_depth;
 
-  typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
-      Matrix;
-  typedef Eigen::Map<Matrix> MatrixRef;
-  typedef Eigen::Map<const Matrix> ConstMatrixRef;
-  ConstMatrixRef hwoi_ordered_filter_matrix_map(
-      hwoi_ordered_filter_data, hwoi_ordered_filter_total_size, input_depth);
+  cpu_backend_gemm::MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = hwoi_ordered_filter_total_size;
+  lhs_params.cols = input_depth;
 
   float* output_data_p = output_data;
   tensor_utils::ZeroVector(output_data, output_offset * batch_size);
   for (int i = 0; i < batch_size; ++i) {
-    ConstMatrixRef input_matrix_map(input_data + input_offset * i,
-                                    input_image_size, input_depth);
-    MatrixRef output_matrix_map(col2im_data, input_image_size,
-                                hwoi_ordered_filter_total_size);
-    Gemm(input_matrix_map, hwoi_ordered_filter_matrix_map.transpose(),
-         &output_matrix_map);
+    cpu_backend_gemm::MatrixParams<float> rhs_params;
+    rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+    rhs_params.rows = input_depth;
+    rhs_params.cols = input_image_size;
+    cpu_backend_gemm::MatrixParams<float> dst_params;
+    dst_params.order = cpu_backend_gemm::Order::kColMajor;
+    dst_params.rows = hwoi_ordered_filter_total_size;
+    dst_params.cols = input_image_size;
+    cpu_backend_gemm::GemmParams<float, float> gemm_params;
+    cpu_backend_gemm::Gemm(lhs_params, hwoi_ordered_filter_data, rhs_params,
+                           input_data + input_offset * i, dst_params,
+                           col2im_data, gemm_params, cpu_backend_context);
     Col2im(col2im_data, output_depth, output_height, output_width,
            filter_height, filter_width, padding_top, padding_left,
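
This hunk is the heart of the port: the Eigen typedefs and maps are replaced by explicit operand descriptions for cpu_backend_gemm. A condensed sketch of that calling convention follows, using only the types and the Gemm overload visible in this diff; the wrapper function, its name, and the m/k/n sizes are invented for illustration.

#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"

// Hypothetical wrapper showing the cpu_backend_gemm calling convention:
// dst (m x n) = lhs (m x k) * rhs (k x n), float path, no bias or clamping.
void GemmFloatSketch(const float* lhs_data, const float* rhs_data,
                     float* dst_data, int m, int k, int n,
                     tflite::CpuBackendContext* cpu_backend_context) {
  using tflite::cpu_backend_gemm::Gemm;
  using tflite::cpu_backend_gemm::GemmParams;
  using tflite::cpu_backend_gemm::MatrixParams;
  using tflite::cpu_backend_gemm::Order;

  // Each MatrixParams describes one operand's storage order and dimensions;
  // the data pointers are passed separately to Gemm().
  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = m;
  lhs_params.cols = k;

  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kColMajor;
  rhs_params.rows = k;
  rhs_params.cols = n;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kColMajor;
  dst_params.rows = m;
  dst_params.cols = n;

  // Default-constructed GemmParams: no bias vector, no activation clamping.
  GemmParams<float, float> gemm_params;
  Gemm(lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data,
       gemm_params, cpu_backend_context);
}

Note that the new dst = lhs × rhs is the transpose of the old row-major input × filterᵀ product, so with the storage orders chosen as above, every operand keeps the same memory layout it had under the Eigen maps.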
@@ -6358,29 +6351,6 @@ inline void TransposeConvV2(
   }
 }
 
-// TODO(renjieliu): Investigate whether we need to keep this.
-inline void TransposeConv(
-    const ConvParams& params, const RuntimeShape& input_shape,
-    const float* input_data, const RuntimeShape& filter_shape,
-    const float* filter_data, const RuntimeShape& output_shape,
-    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
-  gemmlowp::ScopedProfilingLabel label("TransposeConv");
-  // Note we could use transposed weights with forward conv for unstrided
-  // cases. But we are already getting good performance with this code as-is.
-  TFLITE_DCHECK(im2col_data);
-  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
-                  output_shape, im2col_data);
-
-  const auto im2col_matrix_map =
-      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
-  const auto filter_matrix_map =
-      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
-  auto output_matrix_map =
-      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
-
-  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
-}
-
 // Integer-only version of ResizeNearestNeighbor. Since scales are represented
 // in fixed-point and thus approximated, |in_x| or |in_y| may differ from the
 // reference version. Debug checks are in place to test if this occurs.

tensorflow/lite/kernels/transpose_conv.cc

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/cpu_backend_support.h"
 #include "tensorflow/lite/kernels/eigen_support.h"
 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/lite/kernels/internal/tensor.h"
@@ -85,11 +86,13 @@ struct OpData {
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* data = new OpData;
   eigen_support::IncrementUsageCounter(context);
+  cpu_backend_support::IncrementUsageCounter(context);
   return data;
 }
 
 void Free(TfLiteContext* context, void* buffer) {
   eigen_support::DecrementUsageCounter(context);
+  cpu_backend_support::DecrementUsageCounter(context);
   delete reinterpret_cast<OpData*>(buffer);
 }
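
The kernel now ref-counts the shared per-context CPU backend state in Init/Free, mirroring what it already does for eigen_support. Below is a sketch of how that state is then retrieved at eval time, using only the cpu_backend_support calls that appear in this change; the surrounding kernel body is illustrative.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_support.h"

TfLiteStatus EvalSketch(TfLiteContext* context, TfLiteNode* node) {
  // Valid only while this node holds a reference, i.e. between the
  // IncrementUsageCounter call in Init and the DecrementUsageCounter
  // call in Free.
  tflite::CpuBackendContext* backend_context =
      tflite::cpu_backend_support::GetFromContext(context);
  // ... hand backend_context down to optimized_ops::TransposeConvV2 ...
  (void)backend_context;
  return kTfLiteOk;
}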
@@ -306,8 +309,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 template <KernelType kernel_type>
-void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
-               const TfLiteTensor* input, const TfLiteTensor* weights,
+void EvalFloat(TfLiteContext* context, const TfLiteTransposeConvParams* params,
+               const OpData* data, const TfLiteTensor* input,
+               const TfLiteTensor* weights,
                const TfLiteTensor* transposed_weights, TfLiteTensor* col2im,
                TfLiteTensor* output) {
   tflite::ConvParams op_params;
@@ -333,7 +337,8 @@ void EvalFloat(const TfLiteTransposeConvParams* params, const OpData* data,
           GetTensorShape(transposed_weights),
           GetTensorData<float>(transposed_weights), GetTensorShape(output),
           GetTensorData<float>(output), GetTensorShape(col2im),
-          GetTensorData<float>(col2im));
+          GetTensorData<float>(col2im),
+          cpu_backend_support::GetFromContext(context));
       break;
     }
   }
@@ -419,8 +424,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
           ResizeAndTransposeWeights(context, weights, transposed_weights);
         }
       }
-      EvalFloat<kernel_type>(params, data, input, weights, transposed_weights,
-                             col2im, output);
+      EvalFloat<kernel_type>(context, params, data, input, weights,
+                             transposed_weights, col2im, output);
       break;
     }
     case kTfLiteUInt8: {