diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 14bfe7d82ab..5fabf000de8 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -326,14 +326,19 @@ class SelectOpTest(test.TestCase): self.assertAllEqual(np_ans, tf_ans) self.assertShapeEqual(np_ans, out) - def _compareGradientX(self, fn, c, x, y, numeric_gradient_type=None): + def _compareGradientX(self, fn, c, x, y, numeric_gradient_type=None, + x_init_value=None): with self.cached_session(): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) out = fn(c, inx, iny) s = list(np.shape(c)) + if x_init_value is None: + x_init_value = x + if x.shape != y.shape: + x_init_value = np.broadcast_to(y, x.shape) jacob_t, jacob_n = gradient_checker.compute_gradient( - inx, s, out, s, x_init_value=x) + inx, s, out, s, x_init_value=x_init_value) if numeric_gradient_type is not None: xf = x.astype(numeric_gradient_type) yf = y.astype(numeric_gradient_type) @@ -357,7 +362,7 @@ class SelectOpTest(test.TestCase): out = fn(c, inx, iny) s = list(np.shape(c)) jacob_t, jacob_n = gradient_checker.compute_gradient( - iny, s, out, s, x_init_value=y, delta=1.0) + iny, s, out, s, x_init_value=x, delta=1.0) if numeric_gradient_type is not None: xf = x.astype(numeric_gradient_type) yf = y.astype(numeric_gradient_type) @@ -392,6 +397,49 @@ class SelectOpTest(test.TestCase): self._testScalar(array_ops.where) self._testScalar(array_ops.where_v2) + def _testScalarBroadcast(self, fn, c, x, y): + for t in [ + np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64, + np.complex128 + ]: + xt = x.astype(t) + yt = y.astype(t) + self._compare(fn, c, xt, yt, use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._compare(fn, c, xt, yt, use_gpu=True) + + def testScalarBroadcast(self): + c = True + # where_v2 only + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 1) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 3, 1) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 2) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 2) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(3, 2) * 100 + self._testScalarBroadcast(array_ops.where_v2, c, x, y) + self._testScalarBroadcast(array_ops.where_v2, c, y, x) + def _testBasic(self, fn): c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2) x = np.random.rand(1, 3, 2) * 100 @@ -410,6 +458,53 @@ class SelectOpTest(test.TestCase): self._testBasic(array_ops.where) self._testBasic(array_ops.where_v2) + def _testBasicBroadcast(self, fn, c, x, y): + for t in [ + np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64, + np.complex128 + ]: + xt = x.astype(t) + yt = y.astype(t) + self._compare(fn, c, xt, yt, use_gpu=False) + if t in [np.float16, np.float32, np.float64]: + self._compare(fn, c, xt, yt, use_gpu=True) + + def testBasicBroadcast(self): + c0 = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2) + c1 = np.random.randint(0, 2, 2).astype(np.bool).reshape(1, 1, 2) + c2 = np.random.randint(0, 2, 3).astype(np.bool).reshape(1, 3, 1) + c3 = np.random.randint(0, 2, 1).astype(np.bool).reshape(1, 1, 1) + for c in [c0, c1, c2, c3]: + # where_v2 only + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 1) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 3, 1) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 2) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 2) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(3, 2) * 100 + self._testBasicBroadcast(array_ops.where_v2, c, x, y) + self._testBasicBroadcast(array_ops.where_v2, c, y, x) + def _testGradients(self, fn): c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2) x = np.random.rand(1, 3, 2) * 100 @@ -434,6 +529,33 @@ class SelectOpTest(test.TestCase): self._testGradients(array_ops.where) self._testGradients(array_ops.where_v2) + @test_util.run_deprecated_v1 + def testGradientsBroadcast(self): + c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2) + for t in [np.float32, np.float64]: + # where_v2 only + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 1) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 3, 1) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1, 2) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 1) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(1, 2) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + x = np.random.rand(1, 3, 2) * 100 + y = np.random.rand(3, 2) * 100 + self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t)) + def _testShapeMismatch(self, fn): c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2) x = np.random.rand(1, 3, 2) * 100