Add Split and SplitV GPU support for complex64/complex128.
Update split_op_test to test complex numbers in addition to floats and always test with a GPU if one is available. Change: 152757648
This commit is contained in:
		
							parent
							
								
									c0095c9709
								
							
						
					
					
						commit
						df83cd08d2
					
				| @ -50,12 +50,16 @@ void SplitCustom<Device, T>::operator()( | |||||||
| #define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>; | #define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>; | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); | TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); | ||||||
|  | TF_CALL_complex64(DEFINE_GPU_KERNELS); | ||||||
|  | TF_CALL_complex128(DEFINE_GPU_KERNELS); | ||||||
| DEFINE_GPU_KERNELS(bfloat16); | DEFINE_GPU_KERNELS(bfloat16); | ||||||
| 
 | 
 | ||||||
| #undef DEFINE_GPU_KERNELS | #undef DEFINE_GPU_KERNELS | ||||||
| #define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>; | #define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>; | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); | TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); | ||||||
|  | TF_CALL_complex64(DEFINE_GPU_KERNELS); | ||||||
|  | TF_CALL_complex128(DEFINE_GPU_KERNELS); | ||||||
| DEFINE_GPU_KERNELS(bfloat16); | DEFINE_GPU_KERNELS(bfloat16); | ||||||
| 
 | 
 | ||||||
| #undef DEFINE_GPU_KERNELS | #undef DEFINE_GPU_KERNELS | ||||||
| @ -236,12 +240,16 @@ struct SplitVOpGPULaunch { | |||||||
| #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>; | #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>; | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | ||||||
|  | TF_CALL_complex64(REGISTER_GPU_KERNEL); | ||||||
|  | TF_CALL_complex128(REGISTER_GPU_KERNEL); | ||||||
| #undef REGISTER_GPU_KERNEL | #undef REGISTER_GPU_KERNEL | ||||||
| #define REGISTER_GPU_KERNEL(T)                 \ | #define REGISTER_GPU_KERNEL(T)                 \ | ||||||
|   template struct SplitVOpGPULaunch<T, int32>; \ |   template struct SplitVOpGPULaunch<T, int32>; \ | ||||||
|   template struct SplitVOpGPULaunch<T, int64>; |   template struct SplitVOpGPULaunch<T, int64>; | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); | ||||||
|  | TF_CALL_complex64(REGISTER_GPU_KERNEL); | ||||||
|  | TF_CALL_complex128(REGISTER_GPU_KERNEL); | ||||||
| REGISTER_GPU_KERNEL(bfloat16); | REGISTER_GPU_KERNEL(bfloat16); | ||||||
| #undef REGISTER_GPU_KERNEL | #undef REGISTER_GPU_KERNEL | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -337,6 +337,8 @@ REGISTER_SPLIT(quint8); | |||||||
|                           SplitOpGPU<type>) |                           SplitOpGPU<type>) | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); | ||||||
|  | TF_CALL_complex64(REGISTER_GPU); | ||||||
|  | TF_CALL_complex128(REGISTER_GPU); | ||||||
| #undef REGISTER_GPU | #undef REGISTER_GPU | ||||||
| 
 | 
 | ||||||
| #endif  // GOOGLE_CUDA
 | #endif  // GOOGLE_CUDA
 | ||||||
|  | |||||||
| @ -374,6 +374,8 @@ REGISTER_SPLIT_LEN(bfloat16); | |||||||
|   REGISTER_GPU(type, int64); |   REGISTER_GPU(type, int64); | ||||||
| 
 | 
 | ||||||
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN); | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN); | ||||||
|  | TF_CALL_complex64(REGISTER_GPU_LEN); | ||||||
|  | TF_CALL_complex128(REGISTER_GPU_LEN); | ||||||
| REGISTER_GPU_LEN(bfloat16); | REGISTER_GPU_LEN(bfloat16); | ||||||
| #undef REGISTER_GPU_LEN | #undef REGISTER_GPU_LEN | ||||||
| #undef REGISTER_GPU | #undef REGISTER_GPU | ||||||
|  | |||||||
| @ -28,15 +28,24 @@ from tensorflow.python.ops import gradients_impl | |||||||
| from tensorflow.python.ops import math_ops | from tensorflow.python.ops import math_ops | ||||||
| from tensorflow.python.platform import test | from tensorflow.python.platform import test | ||||||
| 
 | 
 | ||||||
|  | _TEST_DTYPES = (dtypes.float32, dtypes.float64, dtypes.complex64, | ||||||
|  |                 dtypes.complex128) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class SplitOpTest(test.TestCase): | class SplitOpTest(test.TestCase): | ||||||
| 
 | 
 | ||||||
|  |   def _makeData(self, shape, dtype): | ||||||
|  |     data = np.random.rand(*shape).astype(dtype.as_numpy_dtype) | ||||||
|  |     if dtype.is_complex: | ||||||
|  |       data -= 1j * data | ||||||
|  |     return data | ||||||
|  | 
 | ||||||
|   def testExplicitNum(self): |   def testExplicitNum(self): | ||||||
|     size_splits = array_ops.placeholder(dtype=dtypes.int32, shape=[None]) |     size_splits = array_ops.placeholder(dtype=dtypes.int32, shape=[None]) | ||||||
| 
 | 
 | ||||||
|     value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] |     value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] | ||||||
| 
 | 
 | ||||||
|     with self.test_session(use_gpu=False) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       with self.assertRaises(ValueError) as context: |       with self.assertRaises(ValueError) as context: | ||||||
|         sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) |         sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]}) | ||||||
| 
 | 
 | ||||||
| @ -55,13 +64,13 @@ class SplitOpTest(test.TestCase): | |||||||
| 
 | 
 | ||||||
|     value = np.random.rand(11, 11) |     value = np.random.rand(11, 11) | ||||||
| 
 | 
 | ||||||
|     with self.test_session(use_gpu=False) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       result = sess.run(array_ops.split(value, [a, b])) |       result = sess.run(array_ops.split(value, [a, b])) | ||||||
| 
 | 
 | ||||||
|     self.assertAllEqual(result[0], value[0:5, :]) |     self.assertAllEqual(result[0], value[0:5, :]) | ||||||
|     self.assertAllEqual(result[1], value[5:, :]) |     self.assertAllEqual(result[1], value[5:, :]) | ||||||
| 
 | 
 | ||||||
|   def _RunAndVerifyVariable(self, use_gpu, large_num_splits=False): |   def _RunAndVerifyVariable(self, dtype, large_num_splits=False): | ||||||
|     # Random dims of rank 5 |     # Random dims of rank 5 | ||||||
|     shape = np.random.randint(1, 5, size=5) |     shape = np.random.randint(1, 5, size=5) | ||||||
|     split_dim = np.random.randint(0, 5) |     split_dim = np.random.randint(0, 5) | ||||||
| @ -71,8 +80,8 @@ class SplitOpTest(test.TestCase): | |||||||
|       num_split = np.random.randint(2, 8) |       num_split = np.random.randint(2, 8) | ||||||
|     size_splits = np.random.randint(2, 8, num_split) |     size_splits = np.random.randint(2, 8, num_split) | ||||||
|     shape[split_dim] = np.sum(size_splits) |     shape[split_dim] = np.sum(size_splits) | ||||||
|     inp = np.random.rand(*shape).astype("f") |     inp = self._makeData(shape, dtype) | ||||||
|     with self.test_session(use_gpu=use_gpu) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       result = sess.run(array_ops.split(inp, size_splits, split_dim)) |       result = sess.run(array_ops.split(inp, size_splits, split_dim)) | ||||||
|     slices = [slice(0, x) for x in shape] |     slices = [slice(0, x) for x in shape] | ||||||
|     offset = 0 |     offset = 0 | ||||||
| @ -81,10 +90,10 @@ class SplitOpTest(test.TestCase): | |||||||
|       offset += size_splits[i] |       offset += size_splits[i] | ||||||
|       self.assertAllEqual(result[i], inp[slices]) |       self.assertAllEqual(result[i], inp[slices]) | ||||||
| 
 | 
 | ||||||
|   def _testSpecialCasesVariable(self, use_gpu): |   def _testSpecialCasesVariable(self): | ||||||
|     inp = np.random.rand(4, 4).astype("f") |     inp = np.random.rand(4, 4).astype("f") | ||||||
| 
 | 
 | ||||||
|     with self.test_session(use_gpu=use_gpu) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       result = sess.run(array_ops.split(inp, [4], 0)) |       result = sess.run(array_ops.split(inp, [4], 0)) | ||||||
|       self.assertAllEqual(result[0], inp) |       self.assertAllEqual(result[0], inp) | ||||||
| 
 | 
 | ||||||
| @ -92,13 +101,13 @@ class SplitOpTest(test.TestCase): | |||||||
|       self.assertAllEqual(result[0], inp[0:1, :]) |       self.assertAllEqual(result[0], inp[0:1, :]) | ||||||
|       self.assertAllEqual(result[1], inp[1:4, :]) |       self.assertAllEqual(result[1], inp[1:4, :]) | ||||||
| 
 | 
 | ||||||
|   def _testHugeNumberOfTensorsVariable(self, use_gpu): |   def _testHugeNumberOfTensorsVariable(self, dtype): | ||||||
|     num_split = 10000 |     num_split = 10000 | ||||||
|     size_splits = np.random.randint(1, 3, num_split) |     size_splits = np.random.randint(1, 3, num_split) | ||||||
|     shape = [3, np.sum(size_splits)] |     shape = [3, np.sum(size_splits)] | ||||||
|     split_dim = 1 |     split_dim = 1 | ||||||
|     inp = np.random.rand(*shape).astype("f") |     inp = self._makeData(shape, dtype) | ||||||
|     with self.test_session(use_gpu=use_gpu) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       result = sess.run(array_ops.split(inp, size_splits, split_dim)) |       result = sess.run(array_ops.split(inp, size_splits, split_dim)) | ||||||
|     slices = [slice(0, x) for x in shape] |     slices = [slice(0, x) for x in shape] | ||||||
|     offset = 0 |     offset = 0 | ||||||
| @ -108,18 +117,17 @@ class SplitOpTest(test.TestCase): | |||||||
|       self.assertAllEqual(result[i], inp[slices]) |       self.assertAllEqual(result[i], inp[slices]) | ||||||
| 
 | 
 | ||||||
|   def testSpecialCasesVariable(self): |   def testSpecialCasesVariable(self): | ||||||
|     self._testSpecialCasesVariable(False) |     self._testSpecialCasesVariable() | ||||||
|     self._testSpecialCasesVariable(True) |     for dtype in _TEST_DTYPES: | ||||||
|     self._testHugeNumberOfTensorsVariable(False) |       self._testHugeNumberOfTensorsVariable(dtype) | ||||||
|     self._testHugeNumberOfTensorsVariable(True) |  | ||||||
| 
 | 
 | ||||||
|   def _testGradientsSimpleVariable(self, use_gpu): |   def _testGradientsSimpleVariable(self, dtype): | ||||||
|     inp = np.random.rand(4, 4).astype("f") |     inp = self._makeData((4, 4), dtype) | ||||||
|     with self.test_session(use_gpu=use_gpu): |     with self.test_session(use_gpu=True): | ||||||
|       inp_tensor = ops.convert_to_tensor(inp) |       inp_tensor = ops.convert_to_tensor(inp) | ||||||
|       s = array_ops.split(inp_tensor, [1, 3], 1) |       s = array_ops.split(inp_tensor, [1, 3], 1) | ||||||
|       inp_grads = [ |       inp_grads = [ | ||||||
|           np.random.rand(4, 1).astype("f"), np.random.rand(4, 3).astype("f") |           self._makeData((4, 1), dtype), self._makeData((4, 3), dtype) | ||||||
|       ] |       ] | ||||||
|       grad_tensors = [constant_op.constant(x) for x in inp_grads] |       grad_tensors = [constant_op.constant(x) for x in inp_grads] | ||||||
|       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[-1] |       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[-1] | ||||||
| @ -129,16 +137,16 @@ class SplitOpTest(test.TestCase): | |||||||
|     self.assertAllEqual(result[:, 1:4], inp_grads[1]) |     self.assertAllEqual(result[:, 1:4], inp_grads[1]) | ||||||
| 
 | 
 | ||||||
|   def testOutputShape(self): |   def testOutputShape(self): | ||||||
|     with self.test_session(use_gpu=False): |     with self.test_session(use_gpu=True): | ||||||
|       tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12]) |       tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12]) | ||||||
|       size_splits = [3, 7, 2] |       size_splits = [3, 7, 2] | ||||||
|       outputs = array_ops.split(tensor, size_splits, 1) |       outputs = array_ops.split(tensor, size_splits, 1) | ||||||
|       for i, output in enumerate(outputs): |       for i, output in enumerate(outputs): | ||||||
|         self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]]) |         self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]]) | ||||||
| 
 | 
 | ||||||
|   def _compare(self, x, dim, num, use_gpu): |   def _compare(self, x, dim, num): | ||||||
|     np_ans = np.split(x, num, dim) |     np_ans = np.split(x, num, dim) | ||||||
|     with self.test_session(use_gpu=use_gpu) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim) |       tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim) | ||||||
|       out = sess.run(tf_ans) |       out = sess.run(tf_ans) | ||||||
|     self.assertEqual(num, len(np_ans)) |     self.assertEqual(num, len(np_ans)) | ||||||
| @ -148,21 +156,15 @@ class SplitOpTest(test.TestCase): | |||||||
|       self.assertAllEqual(np_ans[i], out[i]) |       self.assertAllEqual(np_ans[i], out[i]) | ||||||
|       self.assertShapeEqual(np_ans[i], tf_ans[i]) |       self.assertShapeEqual(np_ans[i], tf_ans[i]) | ||||||
| 
 | 
 | ||||||
|   def _testSplitRows(self, use_gpu): |   def testSplitRows(self): | ||||||
|     inp = np.random.rand(4, 4).astype("f") |     for dtype in _TEST_DTYPES: | ||||||
|     self._compare(inp, 0, 4, use_gpu) |       inp = self._makeData((4, 4), dtype) | ||||||
|  |       self._compare(inp, 0, 4) | ||||||
| 
 | 
 | ||||||
|   def testSplitRowsAll(self): |   def testSplitCols(self): | ||||||
|     self._testSplitRows(use_gpu=False) |     for dtype in _TEST_DTYPES: | ||||||
|     self._testSplitRows(use_gpu=True) |       inp = self._makeData((4, 4), dtype) | ||||||
| 
 |       self._compare(inp, 1, 4) | ||||||
|   def _testSplitCols(self, use_gpu): |  | ||||||
|     inp = np.random.rand(4, 4).astype("f") |  | ||||||
|     self._compare(inp, 1, 4, use_gpu) |  | ||||||
| 
 |  | ||||||
|   def testSplitColsAll(self): |  | ||||||
|     self._testSplitRows(use_gpu=False) |  | ||||||
|     self._testSplitCols(use_gpu=True) |  | ||||||
| 
 | 
 | ||||||
|   def _testEmpty(self, x, dim, num, expected_shape): |   def _testEmpty(self, x, dim, num, expected_shape): | ||||||
|     with self.test_session() as sess: |     with self.test_session() as sess: | ||||||
| @ -177,7 +179,8 @@ class SplitOpTest(test.TestCase): | |||||||
|   def testEmpty(self): |   def testEmpty(self): | ||||||
|     # Note: np.split returns a rank-0 empty ndarray |     # Note: np.split returns a rank-0 empty ndarray | ||||||
|     # if the input ndarray is empty. |     # if the input ndarray is empty. | ||||||
|     inp = np.random.rand(8, 0, 21).astype("f") |     for dtype in _TEST_DTYPES: | ||||||
|  |       inp = self._makeData((8, 0, 21), dtype) | ||||||
|       self._testEmpty(inp, 0, 2, (4, 0, 21)) |       self._testEmpty(inp, 0, 2, (4, 0, 21)) | ||||||
|       self._testEmpty(inp, 0, 4, (2, 0, 21)) |       self._testEmpty(inp, 0, 4, (2, 0, 21)) | ||||||
|       self._testEmpty(inp, 1, 4, (8, 0, 21)) |       self._testEmpty(inp, 1, 4, (8, 0, 21)) | ||||||
| @ -185,19 +188,19 @@ class SplitOpTest(test.TestCase): | |||||||
|       self._testEmpty(inp, 2, 7, (8, 0, 3)) |       self._testEmpty(inp, 2, 7, (8, 0, 3)) | ||||||
| 
 | 
 | ||||||
|   def testIdentity(self): |   def testIdentity(self): | ||||||
|     inp = np.random.rand(2, 2, 2).astype("f") |     for dtype in _TEST_DTYPES: | ||||||
|     for use_gpu in [False, True]: |       inp = self._makeData((2, 2, 2), dtype) | ||||||
|       self._compare(inp, 0, 1, use_gpu) |       self._compare(inp, 0, 1) | ||||||
|       self._compare(inp, 1, 1, use_gpu) |       self._compare(inp, 1, 1) | ||||||
|       self._compare(inp, 2, 1, use_gpu) |       self._compare(inp, 2, 1) | ||||||
| 
 | 
 | ||||||
|   def testSplitDim0(self): |   def testSplitDim0(self): | ||||||
|     for use_gpu in [False, True]: |     for dtype in _TEST_DTYPES: | ||||||
|       self._compare(np.random.rand(6, 10, 18).astype("f"), 0, 3, use_gpu) |       self._compare(self._makeData((6, 10, 18), dtype), 0, 3) | ||||||
|       self._compare(np.random.rand(6, 7, 18).astype("f"), 0, 3, use_gpu) |       self._compare(self._makeData((6, 7, 18), dtype), 0, 3) | ||||||
|       self._compare(np.random.rand(6, 7, 9).astype("f"), 0, 3, use_gpu) |       self._compare(self._makeData((6, 7, 9), dtype), 0, 3) | ||||||
| 
 | 
 | ||||||
|   def _RunAndVerify(self, use_gpu, large_num_splits=False): |   def _RunAndVerify(self, dtype, large_num_splits=False): | ||||||
|     # Random dims of rank 5 |     # Random dims of rank 5 | ||||||
|     shape = np.random.randint(0, 5, size=5) |     shape = np.random.randint(0, 5, size=5) | ||||||
|     split_dim = np.random.randint(0, 5) |     split_dim = np.random.randint(0, 5) | ||||||
| @ -206,8 +209,8 @@ class SplitOpTest(test.TestCase): | |||||||
|     else: |     else: | ||||||
|       num_split = np.random.randint(2, 8) |       num_split = np.random.randint(2, 8) | ||||||
|     shape[split_dim] = np.random.randint(2, 5) * num_split |     shape[split_dim] = np.random.randint(2, 5) * num_split | ||||||
|     inp = np.random.rand(*shape).astype("f") |     inp = self._makeData(shape, dtype) | ||||||
|     with self.test_session(use_gpu=use_gpu) as sess: |     with self.test_session(use_gpu=True) as sess: | ||||||
|       result = sess.run( |       result = sess.run( | ||||||
|           array_ops.split( |           array_ops.split( | ||||||
|               value=inp, num_or_size_splits=num_split, axis=split_dim)) |               value=inp, num_or_size_splits=num_split, axis=split_dim)) | ||||||
| @ -220,20 +223,19 @@ class SplitOpTest(test.TestCase): | |||||||
|       self.assertAllEqual(result[i], inp[slices]) |       self.assertAllEqual(result[i], inp[slices]) | ||||||
| 
 | 
 | ||||||
|   def testRandom(self): |   def testRandom(self): | ||||||
|  |     for dtype in _TEST_DTYPES: | ||||||
|       for _ in range(5): |       for _ in range(5): | ||||||
|       self._RunAndVerify(use_gpu=False) |         self._RunAndVerify(dtype) | ||||||
|       self._RunAndVerify(use_gpu=True) |         self._RunAndVerify(dtype, large_num_splits=True) | ||||||
|       self._RunAndVerify(use_gpu=True, large_num_splits=True) |         self._RunAndVerifyVariable(dtype) | ||||||
|       self._RunAndVerifyVariable(use_gpu=False) |         self._RunAndVerifyVariable(dtype, large_num_splits=True) | ||||||
|       self._RunAndVerifyVariable(use_gpu=True) |  | ||||||
|       self._RunAndVerifyVariable(use_gpu=True, large_num_splits=True) |  | ||||||
| 
 | 
 | ||||||
|   def _testGradientsSimple(self, use_gpu): |   def _testGradientsSimple(self, dtype): | ||||||
|     inp = np.random.rand(4, 4).astype("f") |     inp = self._makeData((4, 4), dtype) | ||||||
|     with self.test_session(use_gpu=use_gpu): |     with self.test_session(use_gpu=True): | ||||||
|       inp_tensor = ops.convert_to_tensor(inp) |       inp_tensor = ops.convert_to_tensor(inp) | ||||||
|       s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1) |       s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1) | ||||||
|       inp_grads = [np.random.rand(4, 1).astype("f") for _ in range(4)] |       inp_grads = [self._makeData((4, 1), dtype)for _ in range(4)] | ||||||
|       grad_tensors = [constant_op.constant(x) for x in inp_grads] |       grad_tensors = [constant_op.constant(x) for x in inp_grads] | ||||||
|       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[0] |       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[0] | ||||||
|       result = grad.eval() |       result = grad.eval() | ||||||
| @ -241,10 +243,9 @@ class SplitOpTest(test.TestCase): | |||||||
|       self.assertAllEqual(result[:, i:i + 1], inp_grads[i]) |       self.assertAllEqual(result[:, i:i + 1], inp_grads[i]) | ||||||
| 
 | 
 | ||||||
|   def testGradientsAll(self): |   def testGradientsAll(self): | ||||||
|     self._testGradientsSimple(use_gpu=False) |     for dtype in _TEST_DTYPES: | ||||||
|     self._testGradientsSimple(use_gpu=True) |       self._testGradientsSimple(dtype) | ||||||
|     self._testGradientsSimpleVariable(use_gpu=False) |       self._testGradientsSimpleVariable(dtype) | ||||||
|     self._testGradientsSimpleVariable(use_gpu=True) |  | ||||||
| 
 | 
 | ||||||
|   def testShapeFunctionEdgeCases(self): |   def testShapeFunctionEdgeCases(self): | ||||||
|     # split_dim greater than rank of input. |     # split_dim greater than rank of input. | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user