diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 67d49eafcde..5f5cae8f176 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -32,6 +32,8 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/pooling_ops_common.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp { OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + ctx, + data_format_ != FORMAT_NCHW_VECT_C && + data_format_ != FORMAT_NHWC_VECT_W, + errors::Unimplemented("XLA does not support the VECT_* data formats. " + "Returning unimplemented from MaxPool to keep " + "Tensorflow's intended optimized MaxPool here.")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 2e47c50acef..c9b1e42d66b 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -605,6 +605,10 @@ class PoolingTest(test.TestCase): use_gpu=use_gpu) @test_util.run_deprecated_v1 + @test_util.xla_allow_fallback( + "Allow VECT_* data formats on newer hardware versions which XLA does not" + " handle." + ) def testMaxPooling(self): for use_gpu in True, False: self._testMaxPoolValidPadding(use_gpu)