[TF:XLA] Avoid lowering vector data formats in Maxpool through XLA.
XLA doesn't handle these formats now, so leave them to Tensorflow to run optimized kernels on. PiperOrigin-RevId: 295793708 Change-Id: I299abebb7abd05d72b0c9d2eeea0bef20f382ce2
This commit is contained in:
parent
11b27dd35a
commit
0151f021ae
@ -32,6 +32,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/pooling_ops_common.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -157,6 +159,13 @@ class MaxPoolOp : public PoolingOp {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
|
||||
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
data_format_ != FORMAT_NCHW_VECT_C &&
|
||||
data_format_ != FORMAT_NHWC_VECT_W,
|
||||
errors::Unimplemented("XLA does not support the VECT_* data formats. "
|
||||
"Returning unimplemented from MaxPool to keep "
|
||||
"Tensorflow's intended optimized MaxPool here."));
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
|
@ -605,6 +605,10 @@ class PoolingTest(test.TestCase):
|
||||
use_gpu=use_gpu)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.xla_allow_fallback(
|
||||
"Allow VECT_* data formats on newer hardware versions which XLA does not"
|
||||
" handle."
|
||||
)
|
||||
def testMaxPooling(self):
|
||||
for use_gpu in True, False:
|
||||
self._testMaxPoolValidPadding(use_gpu)
|
||||
|
Loading…
Reference in New Issue
Block a user