diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index dda24210a0b..5c7e75d50e7 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -173,6 +173,62 @@ struct TransformFilter {
   }
 };
 
+// Writes `in` permuted by `shuffle` to `out` via a rank-3 reshape/shuffle.
+// TODO: Remove. Unused, but it defines Eigen templates other kernels reference.
+template <typename Device, typename T, typename IndexType>
+struct TransformDepth {
+  void operator()(const Device& d,
+                  typename TTypes<T, 4, IndexType>::ConstTensor in,
+                  const Eigen::DSizes<IndexType, 4>& shuffle,
+                  typename TTypes<T, 4, IndexType>::Tensor out) {
+    Eigen::DSizes<IndexType, 3> merged_dims;    // shape of `in` as rank 3
+    Eigen::DSizes<IndexType, 4> expanded_dims;  // rank-4 shape of the result
+    Eigen::DSizes<IndexType, 3> new_shuffle;    // permutation of merged_dims
+
+    // Merge dimensions that won't be shuffled together to speed things up.
+    if (shuffle[1] == 2 && shuffle[2] == 3) {  // merges in dims 2 and 3
+      merged_dims[0] = in.dimension(0);
+      merged_dims[1] = in.dimension(1);
+      merged_dims[2] = in.dimension(2) * in.dimension(3);
+      new_shuffle[0] = shuffle[0];
+      new_shuffle[1] = 2;
+      new_shuffle[2] = shuffle[3];
+      expanded_dims[0] = in.dimension(shuffle[0]);
+      expanded_dims[1] = in.dimension(2);
+      expanded_dims[2] = in.dimension(3);
+      expanded_dims[3] = in.dimension(shuffle[3]);
+    } else if (shuffle[0] == 2 && shuffle[1] == 3) {  // merges in dims 2, 3
+      merged_dims[0] = in.dimension(0);
+      merged_dims[1] = in.dimension(1);
+      merged_dims[2] = in.dimension(2) * in.dimension(3);
+      new_shuffle[0] = 2;
+      new_shuffle[1] = shuffle[2];
+      new_shuffle[2] = shuffle[3];
+      expanded_dims[0] = in.dimension(2);
+      expanded_dims[1] = in.dimension(3);
+      expanded_dims[2] = in.dimension(shuffle[2]);
+      expanded_dims[3] = in.dimension(shuffle[3]);
+    } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 &&
+               shuffle[3] == 2) {  // merges in dims 1 and 2
+      merged_dims[0] = in.dimension(0);
+      merged_dims[1] = in.dimension(1) * in.dimension(2);
+      merged_dims[2] = in.dimension(3);
+      new_shuffle[0] = 0;
+      new_shuffle[1] = 2;
+      new_shuffle[2] = 1;
+      expanded_dims[0] = in.dimension(0);
+      expanded_dims[1] = in.dimension(3);
+      expanded_dims[2] = in.dimension(1);
+      expanded_dims[3] = in.dimension(2);
+    } else {  // NOTE(review): assert is presumably a no-op under NDEBUG
+      assert(false && "unexpected shuffle");
+    }
+
+    out.device(d) =
+        in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims);
+  }
+};
+
 template <typename Device, typename T, typename IndexType, int NDIMS>
 struct PadInput {
   void operator()(const Device& d,
diff --git a/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc
index 53137979ca6..9c92d1f700f 100644
--- a/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc
+++ b/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc
@@ -29,6 +29,8 @@ namespace tensorflow {
 
 namespace functor {
 
+template struct TransformDepth<Eigen::GpuDevice, float, int>;
+
 template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
 template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
                                             /*conjugate=*/true>;