Further improved the performance of the contrast adjustment code by optimizing the broadcast of scalars
Change: 111631992
This commit is contained in:
parent
339db86e2b
commit
d38fecedf5
@ -38,14 +38,11 @@ struct AdjustContrast {
|
|||||||
Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
|
Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
|
||||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||||
Eigen::array<int, 2> reduction_axis{{1, 2}};
|
Eigen::array<int, 2> reduction_axis{{1, 2}};
|
||||||
Eigen::array<int, 4> scalar{{1, 1, 1, 1}};
|
|
||||||
Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
|
Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
|
||||||
Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
|
Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
|
||||||
#else
|
#else
|
||||||
Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >
|
Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2> >
|
||||||
reduction_axis;
|
reduction_axis;
|
||||||
Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<1>,
|
|
||||||
Eigen::type2index<1>, Eigen::type2index<1> > scalar;
|
|
||||||
Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
|
Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
|
||||||
broadcast_dims;
|
broadcast_dims;
|
||||||
broadcast_dims.set(1, height);
|
broadcast_dims.set(1, height);
|
||||||
@ -55,6 +52,7 @@ struct AdjustContrast {
|
|||||||
reshape_dims.set(0, batch);
|
reshape_dims.set(0, batch);
|
||||||
reshape_dims.set(3, channels);
|
reshape_dims.set(3, channels);
|
||||||
#endif
|
#endif
|
||||||
|
Eigen::Sizes<1, 1, 1, 1> scalar;
|
||||||
float num_reduced_coeffs = height * width;
|
float num_reduced_coeffs = height * width;
|
||||||
mean_values.device(d) =
|
mean_values.device(d) =
|
||||||
(input.template cast<float>().sum(reduction_axis).eval() /
|
(input.template cast<float>().sum(reduction_axis).eval() /
|
||||||
@ -88,16 +86,12 @@ struct AdjustContrastv2 {
|
|||||||
Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
|
Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
|
||||||
#if !defined(EIGEN_HAS_INDEX_LIST)
|
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||||
Eigen::array<int, 2> reduction_axis{{0, 1}};
|
Eigen::array<int, 2> reduction_axis{{0, 1}};
|
||||||
Eigen::array<int, 4> scalar{{1, 1, 1, 1}};
|
|
||||||
Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
|
Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
|
||||||
Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
|
Eigen::Tensor<int, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
|
||||||
Eigen::array<int, 4> reduced_dims_first{{1, 2, 0, 3}};
|
Eigen::array<int, 4> reduced_dims_first{{1, 2, 0, 3}};
|
||||||
#else
|
#else
|
||||||
Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1> >
|
Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1> >
|
||||||
reduction_axis;
|
reduction_axis;
|
||||||
Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<1>,
|
|
||||||
Eigen::type2index<1>, Eigen::type2index<1> >
|
|
||||||
scalar;
|
|
||||||
Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
|
Eigen::IndexList<Eigen::type2index<1>, int, int, Eigen::type2index<1> >
|
||||||
broadcast_dims;
|
broadcast_dims;
|
||||||
broadcast_dims.set(1, height);
|
broadcast_dims.set(1, height);
|
||||||
@ -110,6 +104,7 @@ struct AdjustContrastv2 {
|
|||||||
Eigen::type2index<0>, Eigen::type2index<3> >
|
Eigen::type2index<0>, Eigen::type2index<3> >
|
||||||
reduced_dims_first;
|
reduced_dims_first;
|
||||||
#endif
|
#endif
|
||||||
|
Eigen::Sizes<1, 1, 1, 1> scalar;
|
||||||
float num_reduced_coeffs = height * width;
|
float num_reduced_coeffs = height * width;
|
||||||
output.device(d) =
|
output.device(d) =
|
||||||
(input.shuffle(reduced_dims_first).sum(reduction_axis).eval() /
|
(input.shuffle(reduced_dims_first).sum(reduction_axis).eval() /
|
||||||
|
Loading…
x
Reference in New Issue
Block a user