Keep TransformDepth
This commit is contained in:
parent
f5ebab7d17
commit
e0e50cdd41
@ -173,6 +173,62 @@ struct TransformFilter {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO This functor is not used anywhere and should be removed,
|
||||
// but it defines some eigen templates that are referenced in other kernels.
|
||||
template <typename Device, typename T, typename IndexType>
|
||||
struct TransformDepth {
|
||||
void operator()(const Device& d,
|
||||
typename TTypes<T, 4, IndexType>::ConstTensor in,
|
||||
const Eigen::DSizes<IndexType, 4>& shuffle,
|
||||
typename TTypes<T, 4, IndexType>::Tensor out) {
|
||||
Eigen::DSizes<IndexType, 3> merged_dims;
|
||||
Eigen::DSizes<IndexType, 4> expanded_dims;
|
||||
Eigen::DSizes<IndexType, 3> new_shuffle;
|
||||
|
||||
// Merge dimensions that won't be shuffled together to speed things up.
|
||||
if (shuffle[1] == 2 && shuffle[2] == 3) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1);
|
||||
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||
new_shuffle[0] = shuffle[0];
|
||||
new_shuffle[1] = 2;
|
||||
new_shuffle[2] = shuffle[3];
|
||||
expanded_dims[0] = in.dimension(shuffle[0]);
|
||||
expanded_dims[1] = in.dimension(2);
|
||||
expanded_dims[2] = in.dimension(3);
|
||||
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||
} else if (shuffle[0] == 2 && shuffle[1] == 3) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1);
|
||||
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||
new_shuffle[0] = 2;
|
||||
new_shuffle[1] = shuffle[2];
|
||||
new_shuffle[2] = shuffle[3];
|
||||
expanded_dims[0] = in.dimension(2);
|
||||
expanded_dims[1] = in.dimension(3);
|
||||
expanded_dims[2] = in.dimension(shuffle[2]);
|
||||
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||
} else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
|
||||
shuffle[3] == 2) {
|
||||
merged_dims[0] = in.dimension(0);
|
||||
merged_dims[1] = in.dimension(1) * in.dimension(2);
|
||||
merged_dims[2] = in.dimension(3);
|
||||
new_shuffle[0] = 0;
|
||||
new_shuffle[1] = 2;
|
||||
new_shuffle[2] = 1;
|
||||
expanded_dims[0] = in.dimension(0);
|
||||
expanded_dims[1] = in.dimension(3);
|
||||
expanded_dims[2] = in.dimension(1);
|
||||
expanded_dims[3] = in.dimension(2);
|
||||
} else {
|
||||
assert(false && "unexpected shuffle");
|
||||
}
|
||||
|
||||
out.device(d) =
|
||||
in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T, typename IndexType, int NDIMS>
|
||||
struct PadInput {
|
||||
void operator()(const Device& d,
|
||||
|
@ -29,6 +29,8 @@ namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
template struct TransformDepth<Eigen::GpuDevice, float, int>;
|
||||
|
||||
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
|
||||
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
|
||||
/*conjugate=*/true>;
|
||||
|
Loading…
Reference in New Issue
Block a user