parent 70ee164522
commit f171becb02
@@ -1980,12 +1980,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
     gemm_input_shape = &input_shape;
   }
 
-  const int gemm_input_dims = gemm_input_shape->DimensionsCount();
-  int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
-  int n = output_shape.Dims(3);
-  int k = gemm_input_shape->Dims(gemm_input_dims - 1);
-
-#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
   // The following code computes matrix multiplication c = a * transponse(b)
   // with CBLAS, where:
   // * `a` is a matrix with dimensions (m, k).
@@ -1995,6 +1989,12 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
   const float* a = gemm_input_data;
   const float* b = filter_data;
   float* c = output_data;
+  const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+  int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+  int n = output_shape.Dims(3);
+  int k = gemm_input_shape->Dims(gemm_input_dims - 1);
+
+#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
   // The stride of matrix a, b and c respectively.
   int stride_a = k;
   int stride_b = k;
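Note on the two hunks above: `m`, `n`, and `k` move below the `a`/`b`/`c` bindings but stay outside the `#if`, since the new cpu_backend_gemm fallback added below needs them too. Here `m` counts output spatial positions, `n` output channels, and `k` the values in one im2col patch. A minimal sketch with hypothetical shapes; the local `FlatSizeSkipDim` only mimics the TFLite helper of that name:

```cpp
#include <cassert>

// Stand-in for TFLite's FlatSizeSkipDim: product of all dims except one.
int FlatSizeSkipDim(const int* dims, int num_dims, int skip_dim) {
  int size = 1;
  for (int i = 0; i < num_dims; ++i) {
    if (i != skip_dim) size *= dims[i];
  }
  return size;
}

int main() {
  // Hypothetical NHWC shapes: 1x8x8x3 input, 16 filters of 3x3x3, stride 1,
  // "same" padding -> im2col buffer 1x8x8x27, output 1x8x8x16.
  const int gemm_input[4] = {1, 8, 8, 3 * 3 * 3};
  const int output[4] = {1, 8, 8, 16};

  const int m = FlatSizeSkipDim(gemm_input, 4, 3);  // 64 output positions
  const int n = output[3];                          // 16 output channels
  const int k = gemm_input[3];                      // 27 values per patch

  // The convolution then reduces to c(m, n) = a(m, k) * b(n, k)^T.
  assert(m == 64 && n == 16 && k == 27);
  return 0;
}
```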
@@ -2002,32 +2002,36 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
 
   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
               stride_a, b, stride_b, 0.0f, c, stride_c);
-#else
-  // When an optimized CBLAS implementation is not available, fall back
-  // to using Eigen.
-  typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
-      Matrix;
-  typedef Eigen::Map<Matrix> MatrixRef;
-  typedef Eigen::Map<const Matrix> ConstMatrixRef;
-
-  MatrixRef matrix_c(c, m, n);
-  ConstMatrixRef matrix_a(a, m, k);
-  ConstMatrixRef matrix_b(b, n, k);
-
-  // The following special casing for when a or b is a vector is required
-  // as Eigen seem to fail to make this optimization on its own.
-  if (n == 1) {
-    gemmlowp::ScopedProfilingLabel label("GEMV");
-    matrix_c.col(0).noalias() = matrix_a * matrix_b.row(0).transpose();
-  } else if (m == 1) {
-    gemmlowp::ScopedProfilingLabel label("GEMV");
-    matrix_c.row(0).noalias() = matrix_a.row(0) * matrix_b.transpose();
-  } else {
-    gemmlowp::ScopedProfilingLabel label("GEMM");
-    matrix_c.noalias() = matrix_a * matrix_b.transpose();
-  }
-#endif  // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
 
   optimized_ops::AddBiasAndEvalActivationFunction(
       output_activation_min, output_activation_max, bias_shape, bias_data,
       output_shape, output_data);
+#else
+  // When an optimized CBLAS implementation is not available, fall back
+  // to using cpu_backend_gemm.
+  cpu_backend_gemm::MatrixParams<float> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = n;
+  lhs_params.cols = k;
+  cpu_backend_gemm::MatrixParams<float> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = k;
+  rhs_params.cols = m;
+  cpu_backend_gemm::MatrixParams<float> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = n;
+  dst_params.cols = m;
+  cpu_backend_gemm::GemmParams<float, float> gemm_params;
+  gemm_params.bias = bias_data;
+  gemm_params.clamp_min = output_activation_min;
+  gemm_params.clamp_max = output_activation_max;
+  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, gemm_input_data,
+                         dst_params, output_data, gemm_params,
+                         cpu_backend_context);
+#endif  // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
 }
 
 inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
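For reference, the deleted Eigen fallback special-cased `n == 1` and `m == 1` because, per its own comment, Eigen seemed to miss the matrix-vector optimization for those shapes. A self-contained sketch of that pattern with hypothetical sizes, minus the gemmlowp profiling labels (`GemmWithEigen` is a name of my own, not the deleted TFLite code verbatim):

```cpp
#include <Eigen/Core>

#include <vector>

using Matrix =
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// c(m, n) = a(m, k) * b(n, k)^T, with explicit GEMV fast paths.
void GemmWithEigen(const float* a_data, const float* b_data, float* c_data,
                   int m, int n, int k) {
  Eigen::Map<const Matrix> a(a_data, m, k);
  Eigen::Map<const Matrix> b(b_data, n, k);
  Eigen::Map<Matrix> c(c_data, m, n);
  if (n == 1) {
    // Single output column: a matrix-vector product.
    c.col(0).noalias() = a * b.row(0).transpose();
  } else if (m == 1) {
    // Single output row: likewise a GEMV.
    c.row(0).noalias() = a.row(0) * b.transpose();
  } else {
    c.noalias() = a * b.transpose();
  }
}

int main() {
  const int m = 4, n = 3, k = 5;  // hypothetical sizes
  std::vector<float> a(m * k, 1.0f), b(n * k, 2.0f), c(m * n);
  GemmWithEigen(a.data(), b.data(), c.data(), m, n, k);
  // Every output element is a sum of k products 1 * 2, i.e. 10.
  return 0;
}
```

The `noalias()` calls assert that the destination does not overlap the operands, letting Eigen skip a temporary.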
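And what the added cpu_backend_gemm call computes, written out as a plain reference loop under my reading of the new parameters: the filter is the row-major (n, k) LHS; the col-major (k, m) RHS and col-major (n, m) destination alias the same memory as the row-major (m, k) im2col buffer and (m, n) output, so no transpose is materialized; and the bias add plus activation clamp, previously the separate `AddBiasAndEvalActivationFunction` pass, are fused into the GEMM. A hedged sketch (`ReferenceGemm` is my name), not cpu_backend_gemm's actual implementation:

```cpp
#include <algorithm>

// dst(n, m) col-major = lhs(n, k) row-major * rhs(k, m) col-major,
// with per-channel bias and activation clamping fused in.
void ReferenceGemm(const float* lhs, const float* rhs, const float* bias,
                   float clamp_min, float clamp_max, float* dst,
                   int m, int n, int k) {
  for (int col = 0; col < m; ++col) {    // one output position per column
    for (int row = 0; row < n; ++row) {  // one output channel per row
      float acc = 0.0f;
      for (int i = 0; i < k; ++i) {
        // lhs is row-major (n, k); rhs is col-major (k, m), which indexes
        // identically to a row-major (m, k) im2col buffer.
        acc += lhs[row * k + i] * rhs[col * k + i];
      }
      acc += bias[row];
      dst[col * n + row] =  // col-major (n, m) == row-major (m, n) output
          std::min(std::max(acc, clamp_min), clamp_max);
    }
  }
}
```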