Enabled support for negative edge padding for XlaPad.

The underlying HLO pad already supports negative edge padding. We just removed
some checks preventing negative edge padding, and added a test.

PiperOrigin-RevId: 316672291
Change-Id: I38e55f86727d90f66aadb549a587b43df6588571
This commit is contained in:
A. Unique TensorFlower 2020-06-16 06:56:09 -07:00 committed by TensorFlower Gardener
parent 572442eb16
commit 2537f3d413
2 changed files with 18 additions and 8 deletions

View File

@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
dtype=dtype)) dtype=dtype))
def testPadNegative(self):
for dtype in self.numeric_types:
def pad_fn(x):
return xla.pad(
x,
padding_value=7,
padding_low=[0, -1],
padding_high=[1, -2],
padding_interior=[1, 2])
self._assertOpOutputMatchesExpected(
pad_fn,
args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),),
expected=np.array(
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
dtype=dtype))
@test_util.disable_mlir_bridge('Not supported yet') @test_util.disable_mlir_bridge('Not supported yet')
def testReduce(self): def testReduce(self):
for dtype in set(self.numeric_types).intersection( for dtype in set(self.numeric_types).intersection(

View File

@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel {
padding_interior.size(), " vs. ", rank, ")")); padding_interior.size(), " vs. ", rank, ")"));
auto non_negative = [](int64 x) { return x >= 0; }; auto non_negative = [](int64 x) { return x >= 0; };
OP_REQUIRES(
context, absl::c_all_of(padding_low, non_negative),
errors::InvalidArgument("padding_low must be non-negative, got [",
absl::StrJoin(padding_low, ","), "]"));
OP_REQUIRES(
context, absl::c_all_of(padding_high, non_negative),
errors::InvalidArgument("padding_high must be non-negative, got [",
absl::StrJoin(padding_high, ","), "]"));
OP_REQUIRES( OP_REQUIRES(
context, absl::c_all_of(padding_interior, non_negative), context, absl::c_all_of(padding_interior, non_negative),
errors::InvalidArgument("padding_interior must be non-negative, got [", errors::InvalidArgument("padding_interior must be non-negative, got [",