Keep TransformDepth

This commit is contained in:
Yuxin Wu 2019-09-06 23:33:59 -07:00
parent f5ebab7d17
commit e0e50cdd41
2 changed files with 58 additions and 0 deletions

View File

@ -173,6 +173,62 @@ struct TransformFilter {
}
};
// TODO This functor is not used anywhere and should be removed,
// but it defines some eigen templates that are referenced in other kernels.
template <typename Device, typename T, typename IndexType>
struct TransformDepth {
void operator()(const Device& d,
typename TTypes<T, 4, IndexType>::ConstTensor in,
const Eigen::DSizes<IndexType, 4>& shuffle,
typename TTypes<T, 4, IndexType>::Tensor out) {
Eigen::DSizes<IndexType, 3> merged_dims;
Eigen::DSizes<IndexType, 4> expanded_dims;
Eigen::DSizes<IndexType, 3> new_shuffle;
// Merge dimensions that won't be shuffled together to speed things up.
if (shuffle[1] == 2 && shuffle[2] == 3) {
merged_dims[0] = in.dimension(0);
merged_dims[1] = in.dimension(1);
merged_dims[2] = in.dimension(2) * in.dimension(3);
new_shuffle[0] = shuffle[0];
new_shuffle[1] = 2;
new_shuffle[2] = shuffle[3];
expanded_dims[0] = in.dimension(shuffle[0]);
expanded_dims[1] = in.dimension(2);
expanded_dims[2] = in.dimension(3);
expanded_dims[3] = in.dimension(shuffle[3]);
} else if (shuffle[0] == 2 && shuffle[1] == 3) {
merged_dims[0] = in.dimension(0);
merged_dims[1] = in.dimension(1);
merged_dims[2] = in.dimension(2) * in.dimension(3);
new_shuffle[0] = 2;
new_shuffle[1] = shuffle[2];
new_shuffle[2] = shuffle[3];
expanded_dims[0] = in.dimension(2);
expanded_dims[1] = in.dimension(3);
expanded_dims[2] = in.dimension(shuffle[2]);
expanded_dims[3] = in.dimension(shuffle[3]);
} else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
shuffle[3] == 2) {
merged_dims[0] = in.dimension(0);
merged_dims[1] = in.dimension(1) * in.dimension(2);
merged_dims[2] = in.dimension(3);
new_shuffle[0] = 0;
new_shuffle[1] = 2;
new_shuffle[2] = 1;
expanded_dims[0] = in.dimension(0);
expanded_dims[1] = in.dimension(3);
expanded_dims[2] = in.dimension(1);
expanded_dims[3] = in.dimension(2);
} else {
assert(false && "unexpected shuffle");
}
out.device(d) =
in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
}
};
template <typename Device, typename T, typename IndexType, int NDIMS>
struct PadInput {
void operator()(const Device& d,

View File

@ -29,6 +29,8 @@ namespace tensorflow {
namespace functor {
template struct TransformDepth<Eigen::GpuDevice, float, int>;
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
/*conjugate=*/true>;