[TF:XLA] Avoid lowering vector data formats in Maxpool through XLA.
XLA doesn't handle these formats now, so leave them to Tensorflow to run optimized kernels on. PiperOrigin-RevId: 295793708 Change-Id: I299abebb7abd05d72b0c9d2eeea0bef20f382ce2
This commit is contained in:
parent
11b27dd35a
commit
0151f021ae
@ -32,6 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/kernels/pooling_ops_common.h"
|
#include "tensorflow/core/kernels/pooling_ops_common.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp {
|
|||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
|
||||||
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
|
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
|
||||||
errors::InvalidArgument("Invalid data format"));
|
errors::InvalidArgument("Invalid data format"));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx,
|
||||||
|
data_format_ != FORMAT_NCHW_VECT_C &&
|
||||||
|
data_format_ != FORMAT_NHWC_VECT_W,
|
||||||
|
errors::Unimplemented("XLA does not support the VECT_* data formats. "
|
||||||
|
"Returning unimplemented from MaxPool to keep "
|
||||||
|
"Tensorflow's intended optimized MaxPool here."));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
|
@ -605,6 +605,10 @@ class PoolingTest(test.TestCase):
|
|||||||
use_gpu=use_gpu)
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
|
@test_util.xla_allow_fallback(
|
||||||
|
"Allow VECT_* data formats on newer hardware versions which XLA does not"
|
||||||
|
" handle."
|
||||||
|
)
|
||||||
def testMaxPooling(self):
|
def testMaxPooling(self):
|
||||||
for use_gpu in True, False:
|
for use_gpu in True, False:
|
||||||
self._testMaxPoolValidPadding(use_gpu)
|
self._testMaxPoolValidPadding(use_gpu)
|
||||||
|
Loading…
Reference in New Issue
Block a user