diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 00ed6d83e2e..c7be2c55de7 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1579,8 +1579,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) - @test_util.disable_mlir_bridge( - "Requires BroadcastInDim method in MlirHloBuilder") def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) @@ -1591,29 +1589,16 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=x) self._testBinary( array_ops.broadcast_to, - x, - np.array([6, 6], dtype=np.int32), - expected=np.tile(x, [3, 2])) + np.zeros([2, 3], dtype=dtype), + np.array([2, 2, 3], dtype=np.int32), + expected=np.zeros([2, 2, 3], dtype=dtype)) + + x = np.arange(2).reshape((2, 1)).astype(dtype) self._testBinary( array_ops.broadcast_to, x, - np.array([7, 4, 3], dtype=np.int32), - expected=np.tile(x, [7, 2, 1])) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 0, 3], dtype=np.int32), - expected=np.zeros([7, 0, 3], dtype=dtype)) - self._testBinary( - array_ops.broadcast_to, - x, - np.array([7, 1, 2, 9], dtype=np.int32), - expected=np.tile(x, [7, 1, 1, 3])) - self._testBinary( - array_ops.broadcast_to, - np.zeros([2, 0], dtype=dtype), - np.array([4, 0], dtype=np.int32), - expected=np.zeros([4, 0], dtype=dtype)) + np.array([2, 2, 3], dtype=np.int32), + expected=np.tile(x, (2, 1, 3))) x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) self._testBinary(