Enabled support for negative edge padding for XlaPad.
The underlying HLO pad already supports negative edge padding. We just removed some checks preventing negative edge padding, and added a test. PiperOrigin-RevId: 316672291 Change-Id: I38e55f86727d90f66aadb549a587b43df6588571
This commit is contained in:
parent
572442eb16
commit
2537f3d413
|
@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
def testPadNegative(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
||||
def pad_fn(x):
|
||||
return xla.pad(
|
||||
x,
|
||||
padding_value=7,
|
||||
padding_low=[0, -1],
|
||||
padding_high=[1, -2],
|
||||
padding_interior=[1, 2])
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
pad_fn,
|
||||
args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),),
|
||||
expected=np.array(
|
||||
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testReduce(self):
|
||||
for dtype in set(self.numeric_types).intersection(
|
||||
|
|
|
@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel {
|
|||
padding_interior.size(), " vs. ", rank, ")"));
|
||||
|
||||
auto non_negative = [](int64 x) { return x >= 0; };
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_low, non_negative),
|
||||
errors::InvalidArgument("padding_low must be non-negative, got [",
|
||||
absl::StrJoin(padding_low, ","), "]"));
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_high, non_negative),
|
||||
errors::InvalidArgument("padding_high must be non-negative, got [",
|
||||
absl::StrJoin(padding_high, ","), "]"));
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_interior, non_negative),
|
||||
errors::InvalidArgument("padding_interior must be non-negative, got [",
|
||||
|
|
Loading…
Reference in New Issue