Add tf.concat GPU kernels for complex64/complex128.

Also update concat_op_test to always run with use_gpu=True; previously each test ran twice, once with use_gpu=False and once with use_gpu=True.
Change: 152755973
RJ Ryan, 2017-04-10 16:13:55 -08:00 (committed by TensorFlower Gardener)
parent ace2aa7b6f
commit 2f5fde8dd9
4 changed files with 76 additions and 74 deletions
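
For context, a minimal usage sketch of what the new kernels enable (not part of this commit): a complex64 concat pinned to a GPU. This assumes a CUDA-enabled build of TensorFlow ~1.x with at least one visible GPU.

# Sketch: complex64 concat placed explicitly on a GPU (hypothetical example).
import numpy as np
import tensorflow as tf

a = tf.constant(np.array([1 + 2j, 3 + 4j], dtype=np.complex64))
b = tf.constant(np.array([5 + 6j, 7 + 8j], dtype=np.complex64))
with tf.device("/gpu:0"):
  c = tf.concat([a, b], 0)

# With allow_soft_placement=False, placement is strict: the run succeeds only
# if a GPU kernel is registered for complex64 Concat, which this commit adds.
config = tf.ConfigProto(allow_soft_placement=False)
with tf.Session(config=config) as sess:
  print(sess.run(c))  # [1.+2.j 3.+4.j 5.+6.j 7.+8.j]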


@@ -113,6 +113,8 @@ void ConcatGPU(
     Tensor* output, typename TTypes<T, 2>::Tensor* output_flat);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+TF_CALL_complex64(REGISTER);
+TF_CALL_complex128(REGISTER);
 REGISTER(bfloat16);
 #undef REGISTER


@@ -198,15 +198,23 @@ void ConcatGPUImpl(const Eigen::GpuDevice& gpu_device,
     int split_size, typename TTypes<T, 2>::Matrix* output);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
+TF_CALL_complex64(REGISTER_GPUCONCAT32);
+TF_CALL_complex128(REGISTER_GPUCONCAT32);
 REGISTER_GPUCONCAT32(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
+TF_CALL_complex64(REGISTER_GPUCONCAT64);
+TF_CALL_complex128(REGISTER_GPUCONCAT64);
 REGISTER_GPUCONCAT64(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
+TF_CALL_complex64(REGISTER_GPU32);
+TF_CALL_complex128(REGISTER_GPU32);
 REGISTER_GPU32(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
+TF_CALL_complex64(REGISTER_GPU64);
+TF_CALL_complex128(REGISTER_GPU64);
 REGISTER_GPU64(bfloat16);
 #undef REGISTER_GPUCONCAT32


@@ -193,6 +193,8 @@ REGISTER_CONCAT(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 REGISTER_GPU(bfloat16);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
 #undef REGISTER_GPU
 
 // A special GPU kernel for int32.
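
The test changes below fold the separate CPU and GPU runs into a single use_gpu=True session (which still runs on CPU when no GPU is available) and extend the random-value and gradient tests to complex dtypes. As a standalone illustration of the gradient check being added, a hypothetical sketch outside the TF test harness (TF 1.x API; names are illustrative):

# Gradient of concat is exact slicing, so re-concatenating the per-input
# gradients must reproduce the upstream gradient exactly, complex included.
import numpy as np
import tensorflow as tf

axis = 1
parts = []
for x in [1, 2, 6]:
  t = np.random.rand(10, x, 2).astype(np.complex64)
  t += -1j * t  # same trick as the test: give t a nonzero imaginary part
  parts.append(tf.constant(t, dtype=tf.complex64))

c = tf.concat(parts, axis)
grad_inp = np.random.rand(10, 9, 2).astype(np.complex64)
grad_inp += -1j * grad_inp
grads = tf.gradients([c], parts, [tf.constant(grad_inp)])
with tf.Session() as sess:
  np.testing.assert_array_equal(sess.run(tf.concat(grads, axis)), grad_inp)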


@@ -35,7 +35,7 @@ from tensorflow.python.platform import test
 class ConcatOpTest(test.TestCase):
 
   def testHStack(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       c = array_ops.concat([p1, p2], 0)
@@ -50,7 +50,7 @@ class ConcatOpTest(test.TestCase):
     self.assertAllEqual(result[4:, :], params[p2])
 
   def testVStack(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
       c = array_ops.concat([p1, p2], 1)
@@ -76,7 +76,7 @@ class ConcatOpTest(test.TestCase):
     self.assertAllEqual(result[2:, :], p2)
 
   def testRefType(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       p1 = np.random.rand(4, 4).astype("f")
       p2 = np.random.rand(4, 4).astype("f")
       v1 = variables.Variable(p1)
@@ -89,7 +89,7 @@ class ConcatOpTest(test.TestCase):
       self.assertAllEqual(result[:4, :], p1)
       self.assertAllEqual(result[4:, :], p2)
 
-  def _testRandom(self, dtype, use_gpu=False):
+  def _testRandom(self, dtype):
     # Random dims of rank 5
     shape = np.random.randint(1, 5, size=5)
     # Random number of tensors, but always > 1.
@@ -101,7 +101,7 @@ class ConcatOpTest(test.TestCase):
       dtype_feed = dtypes.float32
     else:
       dtype_feed = dtype
-    with self.test_session(use_gpu=use_gpu):
+    with self.test_session(use_gpu=True):
       p = []
       for i in np.arange(num_tensors):
         input_shape = shape
@@ -139,11 +139,11 @@ class ConcatOpTest(test.TestCase):
 
   def testRandom(self):
     self._testRandom(dtypes.float32)
-    self._testRandom(dtypes.float32, use_gpu=True)
     self._testRandom(dtypes.int16)
-    self._testRandom(dtypes.int32, use_gpu=True)
+    self._testRandom(dtypes.int32)
     self._testRandom(dtypes.bfloat16)
-    self._testRandom(dtypes.bfloat16, use_gpu=True)
+    self._testRandom(dtypes.complex64)
+    self._testRandom(dtypes.complex128)
 
   def testInvalidConcatDimTypeAndShape(self):
     a = variables.Variable(constant_op.constant(1.0, shape=[1]))
@@ -166,38 +166,42 @@ class ConcatOpTest(test.TestCase):
     with self.assertRaises(ValueError):
       array_ops.concat(1, constant_op.constant(0, shape=[1]))
 
-  def _testGradientsSimple(self, use_gpu):
+  def _testGradientsSimple(self, dtype):
     # Test both positive and negative concat axis.
     # -2 and 1 correspond to the same axis for 3-dimensional tensors.
     for axis in [-2, 1]:
-      with self.test_session(use_gpu=use_gpu):
+      with self.test_session(use_gpu=True):
         inp = []
         inp_tensors = []
         for x in [1, 2, 6]:
           shape = [10, x, 2]
-          t = np.random.rand(*shape).astype("f")
+          t = np.random.rand(*shape).astype(dtype.as_numpy_dtype)
+          if dtype.is_complex:
+            t += -1j * t
           inp.append(t)
           inp_tensors.append(
               constant_op.constant(
-                  [float(y) for y in t.flatten()],
+                  t.flatten(),
                   shape=shape,
-                  dtype=dtypes.float32))
+                  dtype=dtype))
         c = array_ops.concat(inp_tensors, axis)
         output_shape = [10, 9, 2]
-        grad_inp = np.random.rand(*output_shape).astype("f")
+        grad_inp = np.random.rand(*output_shape).astype(dtype.as_numpy_dtype)
+        if dtype.is_complex:
+          grad_inp += -1j * grad_inp
         grad_tensor = constant_op.constant(
-            [float(x) for x in grad_inp.flatten()], shape=output_shape)
+            grad_inp.flatten(), shape=output_shape)
         grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
         concated_grad = array_ops.concat(grad, axis)
         result = concated_grad.eval()
         self.assertAllEqual(result, grad_inp)
 
-  def testGradientsSimpleAll(self):
-    self._testGradientsSimple(use_gpu=True)
-    self._testGradientsSimple(use_gpu=False)
+  def testGradientsSimple(self):
+    self._testGradientsSimple(dtypes.float32)
+    self._testGradientsSimple(dtypes.complex64)
 
-  def _testGradientsFirstDim(self, use_gpu):
-    with self.test_session(use_gpu=use_gpu):
+  def testGradientsFirstDim(self):
+    with self.test_session(use_gpu=True):
       inp = []
       inp_tensors = []
       for x in [1, 2, 6]:
@@ -206,29 +210,25 @@ class ConcatOpTest(test.TestCase):
         inp.append(t)
         inp_tensors.append(
             constant_op.constant(
-                [float(y) for y in t.flatten()],
+                t.flatten(),
                 shape=shape,
                 dtype=dtypes.float32))
       c = array_ops.concat(inp_tensors, 0)
       output_shape = [9, 10, 2]
       grad_inp = np.random.rand(*output_shape).astype("f")
       grad_tensor = constant_op.constant(
-          [float(x) for x in grad_inp.flatten()], shape=output_shape)
+          grad_inp.flatten(), shape=output_shape)
       grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
       concated_grad = array_ops.concat(grad, 0)
       result = concated_grad.eval()
       self.assertAllEqual(result, grad_inp)
 
-  def testGradientsFirstDimAll(self):
-    self._testGradientsFirstDim(use_gpu=False)
-    self._testGradientsFirstDim(use_gpu=True)
-
-  def _testGradientsLastDim(self, use_gpu):
+  def testGradientsLastDim(self):
     # Test both positive and negative concat axis.
     # -1 and 2 correspond to the same axis for 3-dimensional tensors.
     for axis in [-1, 2]:
-      with self.test_session(use_gpu=use_gpu):
+      with self.test_session(use_gpu=True):
         inp = []
         inp_tensors = []
         for x in [1, 2, 6]:
@@ -263,7 +259,7 @@ class ConcatOpTest(test.TestCase):
     # Random dim to concat on
     concat_dim = np.random.randint(5)
     concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
-    with self.test_session(use_gpu=use_gpu):
+    with self.test_session(use_gpu=True):
      inp = []
      inp_tensors = []
      for x in concat_dim_sizes:
@@ -272,16 +268,13 @@ class ConcatOpTest(test.TestCase):
         t = np.random.rand(*shape).astype("f")
         inp.append(t)
         inp_tensors.append(
-            constant_op.constant(
-                [float(y) for y in t.flatten()],
-                shape=shape,
-                dtype=dtypes.float32))
+            constant_op.constant(t.flatten(), shape=shape,
+                                 dtype=dtypes.float32))
       c = array_ops.concat(inp_tensors, concat_dim)
       output_shape = input_shape
       output_shape[concat_dim] = concat_dim_sizes.sum()
       grad_inp = np.random.rand(*output_shape).astype("f")
-      grad_tensor = constant_op.constant(
-          [float(x) for x in grad_inp.flatten()], shape=output_shape)
+      grad_tensor = constant_op.constant(grad_inp.flatten(), shape=output_shape)
       grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
       concated_grad = array_ops.concat(grad, concat_dim)
       result = concated_grad.eval()
@@ -290,8 +283,7 @@ class ConcatOpTest(test.TestCase):
 
   def testGradientsRandom(self):
     for _ in range(5):
-      self._RunAndVerifyGradientsRandom(use_gpu=False)
-      self._RunAndVerifyGradientsRandom(use_gpu=True)
+      self._RunAndVerifyGradientsRandom()
 
   def testGradientWithUnknownInputDim(self):
     with self.test_session(use_gpu=True):
@@ -302,7 +294,7 @@ class ConcatOpTest(test.TestCase):
       output_shape = [10, 2, 9]
       grad_inp = np.random.rand(*output_shape).astype("f")
       grad_tensor = constant_op.constant(
-          [float(inp) for inp in grad_inp.flatten()], shape=output_shape)
+          grad_inp.flatten(), shape=output_shape)
       grad = gradients_impl.gradients([c], [x, y], [grad_tensor])
       concated_grad = array_ops.concat(grad, 2)
@@ -364,24 +356,23 @@ class ConcatOpTest(test.TestCase):
   def testZeroSize(self):
     # Verify that concat doesn't crash and burn for zero size inputs
     np.random.seed(7)
-    for use_gpu in False, True:
-      with self.test_session(use_gpu=use_gpu) as sess:
-        for shape0 in (), (2,):
-          axis = len(shape0)
-          for shape1 in (), (3,):
-            for n0 in 0, 1, 2:
-              for n1 in 0, 1, 2:
-                x0 = np.random.randn(*(shape0 + (n0,) + shape1))
-                x1 = np.random.randn(*(shape0 + (n1,) + shape1))
-                correct = np.concatenate([x0, x1], axis=axis)
-                # TODO(irving): Make tf.concat handle map, then drop list().
-                xs = list(map(constant_op.constant, [x0, x1]))
-                c = array_ops.concat(xs, axis)
-                self.assertAllEqual(c.eval(), correct)
-                # Check gradients
-                dc = np.random.randn(*c.get_shape().as_list())
-                dxs = sess.run(gradients_impl.gradients(c, xs, dc))
-                self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
+    with self.test_session(use_gpu=True) as sess:
+      for shape0 in (), (2,):
+        axis = len(shape0)
+        for shape1 in (), (3,):
+          for n0 in 0, 1, 2:
+            for n1 in 0, 1, 2:
+              x0 = np.random.randn(*(shape0 + (n0,) + shape1))
+              x1 = np.random.randn(*(shape0 + (n1,) + shape1))
+              correct = np.concatenate([x0, x1], axis=axis)
+              # TODO(irving): Make tf.concat handle map, then drop list().
+              xs = list(map(constant_op.constant, [x0, x1]))
+              c = array_ops.concat(xs, axis)
+              self.assertAllEqual(c.eval(), correct)
+              # Check gradients
+              dc = np.random.randn(*c.get_shape().as_list())
+              dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+              self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
 
   def testTensorConcatDim0Grad(self):
     x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
@@ -565,7 +556,7 @@ class ConcatOpTest(test.TestCase):
       c = array_ops.concat(inp_tensors, axis)
       grad_inp = np.random.rand(*output_shape).astype("f")
       grad_tensor = constant_op.constant(
-          [float(x) for x in grad_inp.flatten()], shape=output_shape)
+          grad_inp.flatten(), shape=output_shape)
       grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
       concated_grad = array_ops.concat(grad, axis)
       result = concated_grad.eval(feed_dict=feed_dict)
@@ -578,7 +569,7 @@ class ConcatOpTest(test.TestCase):
           array_ops.concat(inp_tensors, axis), gather_indexes)
       grad_inp = np.random.rand(*output_shape).astype("f")
       grad_tensor = constant_op.constant(
-          [float(x) for x in grad_inp.flatten()], shape=output_shape)
+          grad_inp.flatten(), shape=output_shape)
       grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
       concated_grad = array_ops.gather(
           array_ops.concat(grad, axis), gather_indexes)
@@ -617,15 +608,14 @@ class ConcatOpTest(test.TestCase):
 class ConcatOffsetTest(test.TestCase):
 
   def testBasic(self):
-    for use_gpu in [False, True]:
-      with self.test_session(use_gpu=use_gpu) as sess:
-        cdim = constant_op.constant(1, dtypes.int32)
-        s0 = constant_op.constant([2, 3, 5], dtypes.int32)
-        s1 = constant_op.constant([2, 7, 5], dtypes.int32)
-        s2 = constant_op.constant([2, 20, 5], dtypes.int32)
-        off = gen_array_ops._concat_offset(cdim, [s0, s1, s2])
-        ans = sess.run(off)
-        self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
+    with self.test_session(use_gpu=True) as sess:
+      cdim = constant_op.constant(1, dtypes.int32)
+      s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+      s1 = constant_op.constant([2, 7, 5], dtypes.int32)
+      s2 = constant_op.constant([2, 20, 5], dtypes.int32)
+      off = gen_array_ops._concat_offset(cdim, [s0, s1, s2])
+      ans = sess.run(off)
+      self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
 
   def testNotVector(self):
     with self.test_session() as sess:
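
For reference, ConcatOffset computes, for each input shape, the starting offset of that input within the concatenated output along the concat dimension. A plain-NumPy-style sketch (hypothetical helper, not TF code) reproducing the expected values in testBasic above:

# Offsets are the running sum of sizes along cdim, zero in other dimensions.
def concat_offset(cdim, shapes):
  offsets, pos = [], 0
  for s in shapes:
    off = [0] * len(s)
    off[cdim] = pos
    offsets.append(off)
    pos += s[cdim]
  return offsets

print(concat_offset(1, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]))
# -> [[0, 0, 0], [0, 3, 0], [0, 10, 0]], matching the test's expected answer.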