Keep TransformDepth
This commit is contained in:
parent
f5ebab7d17
commit
e0e50cdd41
tensorflow/core/kernels
@ -173,6 +173,62 @@ struct TransformFilter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO This functor is not used anywhere and should be removed,
|
||||||
|
// but it defines some eigen templates that are referenced in other kernels.
|
||||||
|
template <typename Device, typename T, typename IndexType>
|
||||||
|
struct TransformDepth {
|
||||||
|
void operator()(const Device& d,
|
||||||
|
typename TTypes<T, 4, IndexType>::ConstTensor in,
|
||||||
|
const Eigen::DSizes<IndexType, 4>& shuffle,
|
||||||
|
typename TTypes<T, 4, IndexType>::Tensor out) {
|
||||||
|
Eigen::DSizes<IndexType, 3> merged_dims;
|
||||||
|
Eigen::DSizes<IndexType, 4> expanded_dims;
|
||||||
|
Eigen::DSizes<IndexType, 3> new_shuffle;
|
||||||
|
|
||||||
|
// Merge dimensions that won't be shuffled together to speed things up.
|
||||||
|
if (shuffle[1] == 2 && shuffle[2] == 3) {
|
||||||
|
merged_dims[0] = in.dimension(0);
|
||||||
|
merged_dims[1] = in.dimension(1);
|
||||||
|
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||||
|
new_shuffle[0] = shuffle[0];
|
||||||
|
new_shuffle[1] = 2;
|
||||||
|
new_shuffle[2] = shuffle[3];
|
||||||
|
expanded_dims[0] = in.dimension(shuffle[0]);
|
||||||
|
expanded_dims[1] = in.dimension(2);
|
||||||
|
expanded_dims[2] = in.dimension(3);
|
||||||
|
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||||
|
} else if (shuffle[0] == 2 && shuffle[1] == 3) {
|
||||||
|
merged_dims[0] = in.dimension(0);
|
||||||
|
merged_dims[1] = in.dimension(1);
|
||||||
|
merged_dims[2] = in.dimension(2) * in.dimension(3);
|
||||||
|
new_shuffle[0] = 2;
|
||||||
|
new_shuffle[1] = shuffle[2];
|
||||||
|
new_shuffle[2] = shuffle[3];
|
||||||
|
expanded_dims[0] = in.dimension(2);
|
||||||
|
expanded_dims[1] = in.dimension(3);
|
||||||
|
expanded_dims[2] = in.dimension(shuffle[2]);
|
||||||
|
expanded_dims[3] = in.dimension(shuffle[3]);
|
||||||
|
} else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
|
||||||
|
shuffle[3] == 2) {
|
||||||
|
merged_dims[0] = in.dimension(0);
|
||||||
|
merged_dims[1] = in.dimension(1) * in.dimension(2);
|
||||||
|
merged_dims[2] = in.dimension(3);
|
||||||
|
new_shuffle[0] = 0;
|
||||||
|
new_shuffle[1] = 2;
|
||||||
|
new_shuffle[2] = 1;
|
||||||
|
expanded_dims[0] = in.dimension(0);
|
||||||
|
expanded_dims[1] = in.dimension(3);
|
||||||
|
expanded_dims[2] = in.dimension(1);
|
||||||
|
expanded_dims[3] = in.dimension(2);
|
||||||
|
} else {
|
||||||
|
assert(false && "unexpected shuffle");
|
||||||
|
}
|
||||||
|
|
||||||
|
out.device(d) =
|
||||||
|
in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename Device, typename T, typename IndexType, int NDIMS>
|
template <typename Device, typename T, typename IndexType, int NDIMS>
|
||||||
struct PadInput {
|
struct PadInput {
|
||||||
void operator()(const Device& d,
|
void operator()(const Device& d,
|
||||||
|
@ -29,6 +29,8 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
template struct TransformDepth<Eigen::GpuDevice, float, int>;
|
||||||
|
|
||||||
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
|
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
|
||||||
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
|
template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
|
||||||
/*conjugate=*/true>;
|
/*conjugate=*/true>;
|
||||||
|
Loading…
Reference in New Issue
Block a user