Add Split and SplitV GPU support for complex64/complex128.

Update split_op_test to test complex numbers in addition to floats and always test with a GPU if one is available.
Change: 152757648
RJ Ryan 2017-04-10 16:31:02 -08:00 committed by TensorFlower Gardener
parent c0095c9709
commit df83cd08d2
4 changed files with 81 additions and 68 deletions
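For context, a minimal sketch (not part of this commit) of what the change enables: splitting a complex-valued tensor with the op placed on the GPU. It assumes a CUDA-enabled TensorFlow build of this era (1.x API) and an available GPU.

    # Sketch only: tf.split on complex64 input, runnable on GPU after this change.
    import numpy as np
    import tensorflow as tf

    # Complex test data, built the same way as the updated test's helper.
    x = np.random.rand(4, 6).astype(np.complex64)
    x -= 1j * x  # x * (1 - 1j): nonzero real and imaginary parts

    with tf.device("/gpu:0"):  # assumes a GPU is present
      parts = tf.split(x, num_or_size_splits=2, axis=1)  # two (4, 3) pieces

    with tf.Session() as sess:
      a, b = sess.run(parts)
      assert np.array_equal(np.concatenate([a, b], axis=1), x)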

View File

@@ -50,12 +50,16 @@ void SplitCustom<Device, T>::operator()(
 #define DEFINE_GPU_KERNELS(T) template struct Split<Eigen::GpuDevice, T>;
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+TF_CALL_complex64(DEFINE_GPU_KERNELS);
+TF_CALL_complex128(DEFINE_GPU_KERNELS);
 DEFINE_GPU_KERNELS(bfloat16);
 #undef DEFINE_GPU_KERNELS
 #define DEFINE_GPU_KERNELS(T) template struct SplitCustom<Eigen::GpuDevice, T>;
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+TF_CALL_complex64(DEFINE_GPU_KERNELS);
+TF_CALL_complex128(DEFINE_GPU_KERNELS);
 DEFINE_GPU_KERNELS(bfloat16);
 #undef DEFINE_GPU_KERNELS
@@ -236,12 +240,16 @@ struct SplitVOpGPULaunch {
 #define REGISTER_GPU_KERNEL(T) template struct SplitOpGPULaunch<T>;
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+TF_CALL_complex64(REGISTER_GPU_KERNEL);
+TF_CALL_complex128(REGISTER_GPU_KERNEL);
 #undef REGISTER_GPU_KERNEL
 #define REGISTER_GPU_KERNEL(T)                 \
   template struct SplitVOpGPULaunch<T, int32>; \
   template struct SplitVOpGPULaunch<T, int64>;
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+TF_CALL_complex64(REGISTER_GPU_KERNEL);
+TF_CALL_complex128(REGISTER_GPU_KERNEL);
 REGISTER_GPU_KERNEL(bfloat16);
 #undef REGISTER_GPU_KERNEL

View File

@@ -337,6 +337,8 @@ REGISTER_SPLIT(quint8);
                           SplitOpGPU<type>)
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
 #undef REGISTER_GPU
 #endif  // GOOGLE_CUDA
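The registration above covers the even-split Split kernel. A hedged way to check that placement actually happens on a CUDA build is TF 1.x's device-placement logging:

    # Sketch: log where the Split op lands for a complex input.
    import numpy as np
    import tensorflow as tf

    x = (np.random.rand(8, 4) - 1j * np.random.rand(8, 4)).astype(np.complex64)
    parts = tf.split(x, 4, axis=0)  # even split -> the Split kernel

    # Placement messages go to stderr; with a GPU available, Split should
    # be assigned to /gpu:0 after this change.
    with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
      sess.run(parts)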

View File

@@ -374,6 +374,8 @@ REGISTER_SPLIT_LEN(bfloat16);
   REGISTER_GPU(type, int64);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN);
+TF_CALL_complex64(REGISTER_GPU_LEN);
+TF_CALL_complex128(REGISTER_GPU_LEN);
 REGISTER_GPU_LEN(bfloat16);
 #undef REGISTER_GPU_LEN
 #undef REGISTER_GPU
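These registrations cover SplitV, the uneven-split variant that tf.split lowers to when given explicit sizes. A minimal sketch, assuming a GPU-enabled session:

    # Sketch: uneven splits of a complex128 tensor, dispatching to SplitV.
    import numpy as np
    import tensorflow as tf

    x = np.random.rand(4, 10).astype(np.complex128)
    x -= 1j * x
    a, b, c = tf.split(x, [2, 5, 3], axis=1)  # size_splits -> SplitV

    with tf.Session() as sess:
      ra, rb, rc = sess.run([a, b, c])
      assert ra.shape == (4, 2) and rb.shape == (4, 5) and rc.shape == (4, 3)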

View File

@@ -28,15 +28,24 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
+_TEST_DTYPES = (dtypes.float32, dtypes.float64, dtypes.complex64,
+                dtypes.complex128)
 class SplitOpTest(test.TestCase):
+  def _makeData(self, shape, dtype):
+    data = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
+    if dtype.is_complex:
+      data -= 1j * data
+    return data
+
   def testExplicitNum(self):
     size_splits = array_ops.placeholder(dtype=dtypes.int32, shape=[None])
     value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       with self.assertRaises(ValueError) as context:
         sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]})
@@ -55,13 +64,13 @@ class SplitOpTest(test.TestCase):
     value = np.random.rand(11, 11)
-    with self.test_session(use_gpu=False) as sess:
+    with self.test_session(use_gpu=True) as sess:
       result = sess.run(array_ops.split(value, [a, b]))
       self.assertAllEqual(result[0], value[0:5, :])
       self.assertAllEqual(result[1], value[5:, :])

-  def _RunAndVerifyVariable(self, use_gpu, large_num_splits=False):
+  def _RunAndVerifyVariable(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(1, 5, size=5)
     split_dim = np.random.randint(0, 5)
@@ -71,8 +80,8 @@ class SplitOpTest(test.TestCase):
       num_split = np.random.randint(2, 8)
     size_splits = np.random.randint(2, 8, num_split)
     shape[split_dim] = np.sum(size_splits)
-    inp = np.random.rand(*shape).astype("f")
-    with self.test_session(use_gpu=use_gpu) as sess:
+    inp = self._makeData(shape, dtype)
+    with self.test_session(use_gpu=True) as sess:
       result = sess.run(array_ops.split(inp, size_splits, split_dim))
     slices = [slice(0, x) for x in shape]
     offset = 0
@@ -81,10 +90,10 @@ class SplitOpTest(test.TestCase):
       offset += size_splits[i]
       self.assertAllEqual(result[i], inp[slices])

-  def _testSpecialCasesVariable(self, use_gpu):
+  def _testSpecialCasesVariable(self):
     inp = np.random.rand(4, 4).astype("f")
-    with self.test_session(use_gpu=use_gpu) as sess:
+    with self.test_session(use_gpu=True) as sess:
       result = sess.run(array_ops.split(inp, [4], 0))
       self.assertAllEqual(result[0], inp)
@@ -92,13 +101,13 @@ class SplitOpTest(test.TestCase):
       self.assertAllEqual(result[0], inp[0:1, :])
       self.assertAllEqual(result[1], inp[1:4, :])

-  def _testHugeNumberOfTensorsVariable(self, use_gpu):
+  def _testHugeNumberOfTensorsVariable(self, dtype):
     num_split = 10000
     size_splits = np.random.randint(1, 3, num_split)
     shape = [3, np.sum(size_splits)]
     split_dim = 1
-    inp = np.random.rand(*shape).astype("f")
-    with self.test_session(use_gpu=use_gpu) as sess:
+    inp = self._makeData(shape, dtype)
+    with self.test_session(use_gpu=True) as sess:
       result = sess.run(array_ops.split(inp, size_splits, split_dim))
     slices = [slice(0, x) for x in shape]
     offset = 0
@@ -108,18 +117,17 @@ class SplitOpTest(test.TestCase):
       self.assertAllEqual(result[i], inp[slices])

   def testSpecialCasesVariable(self):
-    self._testSpecialCasesVariable(False)
-    self._testSpecialCasesVariable(True)
-    self._testHugeNumberOfTensorsVariable(False)
-    self._testHugeNumberOfTensorsVariable(True)
+    self._testSpecialCasesVariable()
+    for dtype in _TEST_DTYPES:
+      self._testHugeNumberOfTensorsVariable(dtype)

-  def _testGradientsSimpleVariable(self, use_gpu):
-    inp = np.random.rand(4, 4).astype("f")
-    with self.test_session(use_gpu=use_gpu):
+  def _testGradientsSimpleVariable(self, dtype):
+    inp = self._makeData((4, 4), dtype)
+    with self.test_session(use_gpu=True):
       inp_tensor = ops.convert_to_tensor(inp)
       s = array_ops.split(inp_tensor, [1, 3], 1)
       inp_grads = [
-          np.random.rand(4, 1).astype("f"), np.random.rand(4, 3).astype("f")
+          self._makeData((4, 1), dtype), self._makeData((4, 3), dtype)
       ]
       grad_tensors = [constant_op.constant(x) for x in inp_grads]
       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[-1]
@@ -129,16 +137,16 @@ class SplitOpTest(test.TestCase):
       self.assertAllEqual(result[:, 1:4], inp_grads[1])

   def testOutputShape(self):
-    with self.test_session(use_gpu=False):
+    with self.test_session(use_gpu=True):
       tensor = array_ops.placeholder(dtypes.float32, shape=[None, 12])
       size_splits = [3, 7, 2]
       outputs = array_ops.split(tensor, size_splits, 1)
       for i, output in enumerate(outputs):
         self.assertEqual(output.get_shape().as_list(), [None, size_splits[i]])

-  def _compare(self, x, dim, num, use_gpu):
+  def _compare(self, x, dim, num):
     np_ans = np.split(x, num, dim)
-    with self.test_session(use_gpu=use_gpu) as sess:
+    with self.test_session(use_gpu=True) as sess:
       tf_ans = array_ops.split(value=x, num_or_size_splits=num, axis=dim)
       out = sess.run(tf_ans)
       self.assertEqual(num, len(np_ans))
@@ -148,21 +156,15 @@ class SplitOpTest(test.TestCase):
         self.assertAllEqual(np_ans[i], out[i])
         self.assertShapeEqual(np_ans[i], tf_ans[i])

-  def _testSplitRows(self, use_gpu):
-    inp = np.random.rand(4, 4).astype("f")
-    self._compare(inp, 0, 4, use_gpu)
+  def testSplitRows(self):
+    for dtype in _TEST_DTYPES:
+      inp = self._makeData((4, 4), dtype)
+      self._compare(inp, 0, 4)

-  def testSplitRowsAll(self):
-    self._testSplitRows(use_gpu=False)
-    self._testSplitRows(use_gpu=True)
-
-  def _testSplitCols(self, use_gpu):
-    inp = np.random.rand(4, 4).astype("f")
-    self._compare(inp, 1, 4, use_gpu)
-
-  def testSplitColsAll(self):
-    self._testSplitRows(use_gpu=False)
-    self._testSplitCols(use_gpu=True)
+  def testSplitCols(self):
+    for dtype in _TEST_DTYPES:
+      inp = self._makeData((4, 4), dtype)
+      self._compare(inp, 1, 4)

   def _testEmpty(self, x, dim, num, expected_shape):
     with self.test_session() as sess:
@@ -177,27 +179,28 @@ class SplitOpTest(test.TestCase):
   def testEmpty(self):
     # Note: np.split returns a rank-0 empty ndarray
     # if the input ndarray is empty.
-    inp = np.random.rand(8, 0, 21).astype("f")
-    self._testEmpty(inp, 0, 2, (4, 0, 21))
-    self._testEmpty(inp, 0, 4, (2, 0, 21))
-    self._testEmpty(inp, 1, 4, (8, 0, 21))
-    self._testEmpty(inp, 2, 3, (8, 0, 7))
-    self._testEmpty(inp, 2, 7, (8, 0, 3))
+    for dtype in _TEST_DTYPES:
+      inp = self._makeData((8, 0, 21), dtype)
+      self._testEmpty(inp, 0, 2, (4, 0, 21))
+      self._testEmpty(inp, 0, 4, (2, 0, 21))
+      self._testEmpty(inp, 1, 4, (8, 0, 21))
+      self._testEmpty(inp, 2, 3, (8, 0, 7))
+      self._testEmpty(inp, 2, 7, (8, 0, 3))

   def testIdentity(self):
-    inp = np.random.rand(2, 2, 2).astype("f")
-    for use_gpu in [False, True]:
-      self._compare(inp, 0, 1, use_gpu)
-      self._compare(inp, 1, 1, use_gpu)
-      self._compare(inp, 2, 1, use_gpu)
+    for dtype in _TEST_DTYPES:
+      inp = self._makeData((2, 2, 2), dtype)
+      self._compare(inp, 0, 1)
+      self._compare(inp, 1, 1)
+      self._compare(inp, 2, 1)

   def testSplitDim0(self):
-    for use_gpu in [False, True]:
-      self._compare(np.random.rand(6, 10, 18).astype("f"), 0, 3, use_gpu)
-      self._compare(np.random.rand(6, 7, 18).astype("f"), 0, 3, use_gpu)
-      self._compare(np.random.rand(6, 7, 9).astype("f"), 0, 3, use_gpu)
+    for dtype in _TEST_DTYPES:
+      self._compare(self._makeData((6, 10, 18), dtype), 0, 3)
+      self._compare(self._makeData((6, 7, 18), dtype), 0, 3)
+      self._compare(self._makeData((6, 7, 9), dtype), 0, 3)

-  def _RunAndVerify(self, use_gpu, large_num_splits=False):
+  def _RunAndVerify(self, dtype, large_num_splits=False):
     # Random dims of rank 5
     shape = np.random.randint(0, 5, size=5)
     split_dim = np.random.randint(0, 5)
@@ -206,8 +209,8 @@ class SplitOpTest(test.TestCase):
     else:
       num_split = np.random.randint(2, 8)
     shape[split_dim] = np.random.randint(2, 5) * num_split
-    inp = np.random.rand(*shape).astype("f")
-    with self.test_session(use_gpu=use_gpu) as sess:
+    inp = self._makeData(shape, dtype)
+    with self.test_session(use_gpu=True) as sess:
       result = sess.run(
           array_ops.split(
               value=inp, num_or_size_splits=num_split, axis=split_dim))
@@ -220,20 +223,19 @@ class SplitOpTest(test.TestCase):
       self.assertAllEqual(result[i], inp[slices])

   def testRandom(self):
-    for _ in range(5):
-      self._RunAndVerify(use_gpu=False)
-      self._RunAndVerify(use_gpu=True)
-      self._RunAndVerify(use_gpu=True, large_num_splits=True)
-      self._RunAndVerifyVariable(use_gpu=False)
-      self._RunAndVerifyVariable(use_gpu=True)
-      self._RunAndVerifyVariable(use_gpu=True, large_num_splits=True)
+    for dtype in _TEST_DTYPES:
+      for _ in range(5):
+        self._RunAndVerify(dtype)
+        self._RunAndVerify(dtype, large_num_splits=True)
+        self._RunAndVerifyVariable(dtype)
+        self._RunAndVerifyVariable(dtype, large_num_splits=True)

-  def _testGradientsSimple(self, use_gpu):
-    inp = np.random.rand(4, 4).astype("f")
-    with self.test_session(use_gpu=use_gpu):
+  def _testGradientsSimple(self, dtype):
+    inp = self._makeData((4, 4), dtype)
+    with self.test_session(use_gpu=True):
       inp_tensor = ops.convert_to_tensor(inp)
       s = array_ops.split(value=inp_tensor, num_or_size_splits=4, axis=1)
-      inp_grads = [np.random.rand(4, 1).astype("f") for _ in range(4)]
+      inp_grads = [self._makeData((4, 1), dtype) for _ in range(4)]
       grad_tensors = [constant_op.constant(x) for x in inp_grads]
       grad = gradients_impl.gradients(s, [inp_tensor], grad_tensors)[0]
       result = grad.eval()
@@ -241,10 +243,9 @@ class SplitOpTest(test.TestCase):
       self.assertAllEqual(result[:, i:i + 1], inp_grads[i])

   def testGradientsAll(self):
-    self._testGradientsSimple(use_gpu=False)
-    self._testGradientsSimple(use_gpu=True)
-    self._testGradientsSimpleVariable(use_gpu=False)
-    self._testGradientsSimpleVariable(use_gpu=True)
+    for dtype in _TEST_DTYPES:
+      self._testGradientsSimple(dtype)
+      self._testGradientsSimpleVariable(dtype)

   def testShapeFunctionEdgeCases(self):
     # split_dim greater than rank of input.
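The pattern the rewritten tests follow can be reproduced standalone; a sketch, where make_data is an illustrative stand-in for the test's _makeData helper and np.split serves as the reference:

    # Sketch: compare tf.split against np.split for every dtype the test covers.
    import numpy as np
    import tensorflow as tf

    def make_data(shape, np_dtype):
      data = np.random.rand(*shape).astype(np_dtype)
      if np.issubdtype(np_dtype, np.complexfloating):
        data -= 1j * data  # mirror the test helper: data * (1 - 1j)
      return data

    for np_dtype in (np.float32, np.float64, np.complex64, np.complex128):
      x = make_data((6, 4), np_dtype)
      expected = np.split(x, 3, axis=0)
      with tf.Session() as sess:  # uses the GPU kernels when one is available
        actual = sess.run(tf.split(x, 3, axis=0))
      for e, a in zip(expected, actual):
        assert np.array_equal(e, a)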