[TF:XLA] Avoid lowering vector data formats in Maxpool through XLA.

XLA doesn't handle these formats now, so leave them to Tensorflow to run optimized kernels on.

PiperOrigin-RevId: 295793708
Change-Id: I299abebb7abd05d72b0c9d2eeea0bef20f382ce2
This commit is contained in:
Tres Popp 2020-02-18 12:17:40 -08:00 committed by TensorFlower Gardener
parent 11b27dd35a
commit 0151f021ae
2 changed files with 13 additions and 0 deletions

View File

@ -32,6 +32,8 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/pooling_ops_common.h" #include "tensorflow/core/kernels/pooling_ops_common.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp {
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format")); errors::InvalidArgument("Invalid data format"));
OP_REQUIRES(
ctx,
data_format_ != FORMAT_NCHW_VECT_C &&
data_format_ != FORMAT_NHWC_VECT_W,
errors::Unimplemented("XLA does not support the VECT_* data formats. "
"Returning unimplemented from MaxPool to keep "
"Tensorflow's intended optimized MaxPool here."));
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {

View File

@ -605,6 +605,10 @@ class PoolingTest(test.TestCase):
use_gpu=use_gpu) use_gpu=use_gpu)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.xla_allow_fallback(
"Allow VECT_* data formats on newer hardware versions which XLA does not"
" handle."
)
def testMaxPooling(self): def testMaxPooling(self):
for use_gpu in True, False: for use_gpu in True, False:
self._testMaxPoolValidPadding(use_gpu) self._testMaxPoolValidPadding(use_gpu)