Added additional gradient tests for where_v2
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
518a2bad14
commit
0566e47d73
@ -326,14 +326,19 @@ class SelectOpTest(test.TestCase):
|
||||
self.assertAllEqual(np_ans, tf_ans)
|
||||
self.assertShapeEqual(np_ans, out)
|
||||
|
||||
def _compareGradientX(self, fn, c, x, y, numeric_gradient_type=None):
|
||||
def _compareGradientX(self, fn, c, x, y, numeric_gradient_type=None,
|
||||
x_init_value=None):
|
||||
with self.cached_session():
|
||||
inx = ops.convert_to_tensor(x)
|
||||
iny = ops.convert_to_tensor(y)
|
||||
out = fn(c, inx, iny)
|
||||
s = list(np.shape(c))
|
||||
if x_init_value is None:
|
||||
x_init_value = x
|
||||
if x.shape != y.shape:
|
||||
x_init_value = np.broadcast_to(y, x.shape)
|
||||
jacob_t, jacob_n = gradient_checker.compute_gradient(
|
||||
inx, s, out, s, x_init_value=x)
|
||||
inx, s, out, s, x_init_value=x_init_value)
|
||||
if numeric_gradient_type is not None:
|
||||
xf = x.astype(numeric_gradient_type)
|
||||
yf = y.astype(numeric_gradient_type)
|
||||
@ -357,7 +362,7 @@ class SelectOpTest(test.TestCase):
|
||||
out = fn(c, inx, iny)
|
||||
s = list(np.shape(c))
|
||||
jacob_t, jacob_n = gradient_checker.compute_gradient(
|
||||
iny, s, out, s, x_init_value=y, delta=1.0)
|
||||
iny, s, out, s, x_init_value=x, delta=1.0)
|
||||
if numeric_gradient_type is not None:
|
||||
xf = x.astype(numeric_gradient_type)
|
||||
yf = y.astype(numeric_gradient_type)
|
||||
@ -392,6 +397,49 @@ class SelectOpTest(test.TestCase):
|
||||
self._testScalar(array_ops.where)
|
||||
self._testScalar(array_ops.where_v2)
|
||||
|
||||
def _testScalarBroadcast(self, fn, c, x, y):
|
||||
for t in [
|
||||
np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
|
||||
np.complex128
|
||||
]:
|
||||
xt = x.astype(t)
|
||||
yt = y.astype(t)
|
||||
self._compare(fn, c, xt, yt, use_gpu=False)
|
||||
if t in [np.float16, np.float32, np.float64]:
|
||||
self._compare(fn, c, xt, yt, use_gpu=True)
|
||||
|
||||
def testScalarBroadcast(self):
|
||||
c = True
|
||||
# where_v2 only
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 1) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 3, 1) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 2) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 2) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(3, 2) * 100
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testScalarBroadcast(array_ops.where_v2, c, y, x)
|
||||
|
||||
def _testBasic(self, fn):
|
||||
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
@ -410,6 +458,53 @@ class SelectOpTest(test.TestCase):
|
||||
self._testBasic(array_ops.where)
|
||||
self._testBasic(array_ops.where_v2)
|
||||
|
||||
def _testBasicBroadcast(self, fn, c, x, y):
|
||||
for t in [
|
||||
np.float16, np.float32, np.float64, np.int32, np.int64, np.complex64,
|
||||
np.complex128
|
||||
]:
|
||||
xt = x.astype(t)
|
||||
yt = y.astype(t)
|
||||
self._compare(fn, c, xt, yt, use_gpu=False)
|
||||
if t in [np.float16, np.float32, np.float64]:
|
||||
self._compare(fn, c, xt, yt, use_gpu=True)
|
||||
|
||||
def testBasicBroadcast(self):
|
||||
c0 = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
|
||||
c1 = np.random.randint(0, 2, 2).astype(np.bool).reshape(1, 1, 2)
|
||||
c2 = np.random.randint(0, 2, 3).astype(np.bool).reshape(1, 3, 1)
|
||||
c3 = np.random.randint(0, 2, 1).astype(np.bool).reshape(1, 1, 1)
|
||||
for c in [c0, c1, c2, c3]:
|
||||
# where_v2 only
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 1) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 3, 1) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 2) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 2) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(3, 2) * 100
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, x, y)
|
||||
self._testBasicBroadcast(array_ops.where_v2, c, y, x)
|
||||
|
||||
def _testGradients(self, fn):
|
||||
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
@ -434,6 +529,33 @@ class SelectOpTest(test.TestCase):
|
||||
self._testGradients(array_ops.where)
|
||||
self._testGradients(array_ops.where_v2)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradientsBroadcast(self):
|
||||
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
|
||||
for t in [np.float32, np.float64]:
|
||||
# where_v2 only
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 1) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 3, 1) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1, 2) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 1) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(1, 2) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
y = np.random.rand(3, 2) * 100
|
||||
self._compareGradientX(array_ops.where_v2, c, x.astype(t), y.astype(t))
|
||||
|
||||
def _testShapeMismatch(self, fn):
|
||||
c = np.random.randint(0, 2, 6).astype(np.bool).reshape(1, 3, 2)
|
||||
x = np.random.rand(1, 3, 2) * 100
|
||||
|
Loading…
Reference in New Issue
Block a user