[TF2XLA] Fix the extract image patches implementation to handle more than one channel.

PiperOrigin-RevId: 248790473
This commit is contained in:
Blake Hechtman 2019-05-17 14:41:04 -07:00 committed by TensorFlower Gardener
parent 373c462ee7
commit 5c27f716a7
2 changed files with 25 additions and 0 deletions

View File

@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
padding="VALID",
patches=patches)
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 2]
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
# [1, 1, 1, 8]
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="VALID",
patches=patches)
if __name__ == "__main__":
test.main()

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/tensor_format.h"
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
xla::XlaOp conv =
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
lhs_dilation, rhs_dilation, dims, depth);
// Feature group convolution, will end up with the kernel_size change more
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
conv_dims.back() = depth;
conv_dims.push_back(kernel_size);
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
conv_dims.pop_back();
conv_dims.back() *= kernel_size;
conv = xla::Reshape(conv, conv_dims);
ctx->SetOutput(0, conv);
}