[TF2XLA] Fix the extract image patches implementation to handle more than one channel.
PiperOrigin-RevId: 248790473
This commit is contained in:
parent
373c462ee7
commit
5c27f716a7
@ -130,5 +130,20 @@ class ExtractImagePatches(xla_test.XLATestCase):
|
|||||||
padding="VALID",
|
padding="VALID",
|
||||||
patches=patches)
|
patches=patches)
|
||||||
|
|
||||||
|
def testKsize2x2Stride1x1Rate1x1ValidDepth2(self):
|
||||||
|
"""Test for 2x2 kernel with VALID padding."""
|
||||||
|
# [1, 2, 2, 2]
|
||||||
|
image = [[[[1, 5], [2, 6]], [[3, 7], [4, 8]]]]
|
||||||
|
# [1, 1, 1, 8]
|
||||||
|
patches = [[[[1, 5, 2, 6, 3, 7, 4, 8]]]]
|
||||||
|
self._VerifyValues(
|
||||||
|
image,
|
||||||
|
ksizes=[2, 2],
|
||||||
|
strides=[1, 1],
|
||||||
|
rates=[1, 1],
|
||||||
|
padding="VALID",
|
||||||
|
patches=patches)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
@ -150,6 +151,15 @@ class ExtractImagePatchesOp : public XlaOpKernel {
|
|||||||
xla::XlaOp conv =
|
xla::XlaOp conv =
|
||||||
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
|
||||||
lhs_dilation, rhs_dilation, dims, depth);
|
lhs_dilation, rhs_dilation, dims, depth);
|
||||||
|
// Feature group convolution, will end up with the kernel_size change more
|
||||||
|
// rapidly than the depth. Reshape, transpose and reshape to reorder them.
|
||||||
|
auto conv_dims = builder->GetShape(conv).ValueOrDie().dimensions();
|
||||||
|
conv_dims.back() = depth;
|
||||||
|
conv_dims.push_back(kernel_size);
|
||||||
|
conv = xla::TransposeInMinorDims(xla::Reshape(conv, conv_dims));
|
||||||
|
conv_dims.pop_back();
|
||||||
|
conv_dims.back() *= kernel_size;
|
||||||
|
conv = xla::Reshape(conv, conv_dims);
|
||||||
ctx->SetOutput(0, conv);
|
ctx->SetOutput(0, conv);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user