[TF2XLA] Fix the extract image patches implementation to handle more than one channel.
PiperOrigin-RevId: 248790473
This commit is contained in:
parent
373c462ee7
commit
5c27f716a7
@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
|
||||
padding="VALID",
|
||||
patches=patches)
|
||||
|
||||
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
|
||||
"""Test for 2x2 kernel with VALID padding."""
|
||||
# [1, 2, 2, 2]
|
||||
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
|
||||
# [1, 1, 1, 8]
|
||||
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
|
||||
self._VerifyValues(
|
||||
image,
|
||||
ksizes=[2, 2],
|
||||
strides=[1, 1],
|
||||
rates=[1, 1],
|
||||
padding="VALID",
|
||||
patches=patches)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
||||
xla::XlaOp conv =
|
||||
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dims, depth);
|
||||
// Feature group convolution, will end up with the kernel_size change more
|
||||
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
|
||||
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
|
||||
conv_dims.back() = depth;
|
||||
conv_dims.push_back(kernel_size);
|
||||
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
|
||||
conv_dims.pop_back();
|
||||
conv_dims.back() *= kernel_size;
|
||||
conv = xla::Reshape(conv, conv_dims);
|
||||
ctx->SetOutput(0, conv);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user