Remove illegal BroadcastTo op compiler tests
BroadcastTo op requires input shape to be broadcast compatible with the required shape and can't modify dimensions of size greater than one. Added couple of legal tests to improve coverage. These were failing in shape inference function and then failing to get lowered in the MLIR bridge. PiperOrigin-RevId: 312696176 Change-Id: I42a85618b8bbf6ff9dce46de01e6ad3b319a269f
This commit is contained in:
parent
bfe0b28c37
commit
c2534e2336
@ -1579,8 +1579,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([4, 5, 6], dtype=np.int32),
|
||||
expected=None)
|
||||
|
||||
@test_util.disable_mlir_bridge(
|
||||
"Requires BroadcastInDim method in MlirHloBuilder")
|
||||
def testBroadcastTo(self):
|
||||
for dtype in self.all_types:
|
||||
x = np.random.randint(0, high=100, size=[2, 3])
|
||||
@ -1591,29 +1589,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
expected=x)
|
||||
self._testBinary(
|
||||
array_ops.broadcast_to,
|
||||
x,
|
||||
np.array([6, 6], dtype=np.int32),
|
||||
expected=np.tile(x, [3, 2]))
|
||||
np.zeros([2, 3], dtype=dtype),
|
||||
np.array([2, 2, 3], dtype=np.int32),
|
||||
expected=np.zeros([2, 2, 3], dtype=dtype))
|
||||
|
||||
x = np.arange(2).reshape((2, 1)).astype(dtype)
|
||||
self._testBinary(
|
||||
array_ops.broadcast_to,
|
||||
x,
|
||||
np.array([7, 4, 3], dtype=np.int32),
|
||||
expected=np.tile(x, [7, 2, 1]))
|
||||
self._testBinary(
|
||||
array_ops.broadcast_to,
|
||||
x,
|
||||
np.array([7, 0, 3], dtype=np.int32),
|
||||
expected=np.zeros([7, 0, 3], dtype=dtype))
|
||||
self._testBinary(
|
||||
array_ops.broadcast_to,
|
||||
x,
|
||||
np.array([7, 1, 2, 9], dtype=np.int32),
|
||||
expected=np.tile(x, [7, 1, 1, 3]))
|
||||
self._testBinary(
|
||||
array_ops.broadcast_to,
|
||||
np.zeros([2, 0], dtype=dtype),
|
||||
np.array([4, 0], dtype=np.int32),
|
||||
expected=np.zeros([4, 0], dtype=dtype))
|
||||
np.array([2, 2, 3], dtype=np.int32),
|
||||
expected=np.tile(x, (2, 1, 3)))
|
||||
|
||||
x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype)
|
||||
self._testBinary(
|
||||
|
Loading…
x
Reference in New Issue
Block a user