Enabled support for negative edge padding for XlaPad.
The underlying HLO pad already supports negative edge padding. We just removed some checks preventing negative edge padding, and added a test. PiperOrigin-RevId: 316672291 Change-Id: I38e55f86727d90f66aadb549a587b43df6588571
This commit is contained in:
parent
572442eb16
commit
2537f3d413
|
@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||||
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
def testPadNegative(self):
|
||||||
|
for dtype in self.numeric_types:
|
||||||
|
|
||||||
|
def pad_fn(x):
|
||||||
|
return xla.pad(
|
||||||
|
x,
|
||||||
|
padding_value=7,
|
||||||
|
padding_low=[0, -1],
|
||||||
|
padding_high=[1, -2],
|
||||||
|
padding_interior=[1, 2])
|
||||||
|
|
||||||
|
self._assertOpOutputMatchesExpected(
|
||||||
|
pad_fn,
|
||||||
|
args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),),
|
||||||
|
expected=np.array(
|
||||||
|
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
|
||||||
|
dtype=dtype))
|
||||||
|
|
||||||
@test_util.disable_mlir_bridge('Not supported yet')
|
@test_util.disable_mlir_bridge('Not supported yet')
|
||||||
def testReduce(self):
|
def testReduce(self):
|
||||||
for dtype in set(self.numeric_types).intersection(
|
for dtype in set(self.numeric_types).intersection(
|
||||||
|
|
|
@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel {
|
||||||
padding_interior.size(), " vs. ", rank, ")"));
|
padding_interior.size(), " vs. ", rank, ")"));
|
||||||
|
|
||||||
auto non_negative = [](int64 x) { return x >= 0; };
|
auto non_negative = [](int64 x) { return x >= 0; };
|
||||||
OP_REQUIRES(
|
|
||||||
context, absl::c_all_of(padding_low, non_negative),
|
|
||||||
errors::InvalidArgument("padding_low must be non-negative, got [",
|
|
||||||
absl::StrJoin(padding_low, ","), "]"));
|
|
||||||
OP_REQUIRES(
|
|
||||||
context, absl::c_all_of(padding_high, non_negative),
|
|
||||||
errors::InvalidArgument("padding_high must be non-negative, got [",
|
|
||||||
absl::StrJoin(padding_high, ","), "]"));
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, absl::c_all_of(padding_interior, non_negative),
|
context, absl::c_all_of(padding_interior, non_negative),
|
||||||
errors::InvalidArgument("padding_interior must be non-negative, got [",
|
errors::InvalidArgument("padding_interior must be non-negative, got [",
|
||||||
|
|
Loading…
Reference in New Issue