Port Conv/float to cpu_backend_gemm.
PiperOrigin-RevId: 247073226
This commit is contained in:
parent e8e5627993
commit e9710039be
@@ -1980,6 +1980,12 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
    gemm_input_shape = &input_shape;
  }

  const int gemm_input_dims = gemm_input_shape->DimensionsCount();
  int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
  int n = output_shape.Dims(3);
  int k = gemm_input_shape->Dims(gemm_input_dims - 1);

#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
  // The following code computes matrix multiplication c = a * transpose(b)
  // with CBLAS, where:
  // * `a` is a matrix with dimensions (m, k).
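Note (not part of the diff): a quick sketch, with made-up shapes, of what these GEMM dimensions work out to. `m` is the number of output pixels (the product of all dimensions of the im2col'd input except the last, which is what FlatSizeSkipDim computes), `n` is the number of output channels, and `k` is the length of one im2col patch.

#include <cstdio>

int main() {
  // Hypothetical shapes: im2col'd input is [batch, out_h, out_w, patch],
  // output is [batch, out_h, out_w, out_channels].
  const int gemm_input_dims[4] = {1, 8, 8, 3 * 3 * 16};
  const int output_dims[4] = {1, 8, 8, 32};
  // FlatSizeSkipDim(shape, last) == product of all dimensions but the last.
  const int m = gemm_input_dims[0] * gemm_input_dims[1] * gemm_input_dims[2];
  const int n = output_dims[3];
  const int k = gemm_input_dims[3];
  printf("m=%d n=%d k=%d\n", m, n, k);  // m=64 n=32 k=144
}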
@@ -1989,12 +1995,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
  const float* a = gemm_input_data;
  const float* b = filter_data;
  float* c = output_data;
  const int gemm_input_dims = gemm_input_shape->DimensionsCount();
  int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
  int n = output_shape.Dims(3);
  int k = gemm_input_shape->Dims(gemm_input_dims - 1);

#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
  // The stride of matrix a, b and c respectively.
  int stride_a = k;
  int stride_b = k;
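Note (not part of the diff): the CBLAS path computes c(m x n) = a(m x k) * transpose(b(n x k)) in row-major storage, with each leading dimension set to the contiguous row length. A standalone sketch of the same call convention, with hypothetical sizes; any CBLAS implementation (Accelerate, OpenBLAS) will do.

#include <cblas.h>
#include <cstdio>

int main() {
  const int m = 2, n = 3, k = 4;
  // Row-major: a is m x k, b is n x k, c is m x n.
  const float a[] = {1, 2, 3, 4, 5, 6, 7, 8};
  const float b[] = {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0};
  float c[6] = {};
  // c = 1.0 * a * transpose(b) + 0.0 * c; the leading dimensions are the
  // contiguous row lengths: k for a and b, n for c.
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f,
              a, k, b, k, 0.0f, c, n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%g ", c[i * n + j]);
    printf("\n");  // prints: 1 2 3 / 5 6 7
  }
}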
@@ -2002,36 +2002,32 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,

  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
              stride_a, b, stride_b, 0.0f, c, stride_c);
#else
  // When an optimized CBLAS implementation is not available, fall back
  // to using Eigen.
  typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
      Matrix;
  typedef Eigen::Map<Matrix> MatrixRef;
  typedef Eigen::Map<const Matrix> ConstMatrixRef;

  MatrixRef matrix_c(c, m, n);
  ConstMatrixRef matrix_a(a, m, k);
  ConstMatrixRef matrix_b(b, n, k);

  // The following special casing for when a or b is a vector is required
  // as Eigen seems to fail to make this optimization on its own.
  if (n == 1) {
    gemmlowp::ScopedProfilingLabel label("GEMV");
    matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
  } else if (m == 1) {
    gemmlowp::ScopedProfilingLabel label("GEMV");
    matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
  } else {
    gemmlowp::ScopedProfilingLabel label("GEMM");
    matrix_c.noalias() = matrix_a * matrix_b.transpose();
  }

#endif  // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)

  optimized_ops::AddBiasAndEvalActivationFunction(
      output_activation_min, output_activation_max, bias_shape, bias_data,
      output_shape, output_data);
#else
  // When an optimized CBLAS implementation is not available, fall back
  // to using cpu_backend_gemm.
  cpu_backend_gemm::MatrixParams<float> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = n;
  lhs_params.cols = k;
  cpu_backend_gemm::MatrixParams<float> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = k;
  rhs_params.cols = m;
  cpu_backend_gemm::MatrixParams<float> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = n;
  dst_params.cols = m;
  cpu_backend_gemm::GemmParams<float, float> gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
#endif  // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
}

inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
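Note (not part of the diff): the new path leans on a layout identity rather than a transpose flag. cpu_backend_gemm computes dst = lhs * rhs; here lhs is the filter viewed as a row-major n x k matrix, rhs reinterprets the row-major m x k im2col buffer as a column-major k x m matrix (same bytes), and dst is declared column-major n x m, which occupies memory exactly like the row-major m x n output. Bias addition and activation clamping are fused in via GemmParams. A naive-loop sketch of the layout identity follows; cpu_backend_gemm itself is TFLite-internal, so this only mirrors the indexing.

#include <cassert>
#include <vector>

int main() {
  const int m = 2, n = 3, k = 4;
  std::vector<float> input(m * k), filter(n * k);
  for (int i = 0; i < m * k; ++i) input[i] = 1.0f + i;
  for (int i = 0; i < n * k; ++i) filter[i] = static_cast<float>(i % 3 - 1);

  // Reference result, row-major [m, n]: out[i][j] = sum_c in[i][c] * w[j][c].
  std::vector<float> reference(m * n);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int c = 0; c < k; ++c)
        acc += input[i * k + c] * filter[j * k + c];
      reference[i * n + j] = acc;
    }

  // cpu_backend_gemm view: dst(col-major n x m) = lhs(row-major n x k)
  // * rhs(col-major k x m). rhs(c, i) lands on the same byte as the
  // row-major input(i, c); dst(j, i) lands on the same byte as the
  // row-major output(i, j).
  std::vector<float> dst(m * n);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int c = 0; c < k; ++c)
        acc += filter[j * k + c] * input[i * k + c];
      dst[i * n + j] = acc;  // column-major (n x m) indexing
    }

  for (int i = 0; i < m * n; ++i) assert(reference[i] == dst[i]);
  return 0;
}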
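Note (not part of the diff): in the removed Eigen fallback, the n == 1 / m == 1 branches rewrite the product as a matrix-vector multiply so that Eigen dispatches its GEMV kernel; per the comment, Eigen does not reliably make that reduction on its own. A minimal standalone illustration of the same dispatch pattern, assuming Eigen is available:

#include <Eigen/Core>
#include <iostream>

int main() {
  using Matrix =
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  const int m = 4, k = 3;
  int n = 1;  // single output column -> GEMV case
  Matrix a = Matrix::Random(m, k);
  Matrix b = Matrix::Random(n, k);
  Matrix c(m, n);
  if (n == 1) {
    // Matrix-vector form: (m x k) * (k x 1) hits Eigen's GEMV kernel.
    c.col(0).noalias() = a * b.row(0).transpose();
  } else {
    // General form: Eigen treats (m x k) * (k x n) as GEMM.
    c.noalias() = a * b.transpose();
  }
  std::cout << c << "\n";
}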