Let Concat properly handle concat dim > 2^31 when dealing with Very Large Tensors.

Change: 123258451
This commit is contained in:
David G. Andersen 2016-05-25 14:10:46 -08:00 committed by TensorFlower Gardener
parent 6e89233b74
commit 8e9f29598a
2 changed files with 12 additions and 1 deletion

View File

@ -76,7 +76,7 @@ class ConcatOp : public OpKernel {
for (int d = 0; d < concat_dim; ++d) {
inputs_flat_dim0 *= input_shape.dim_size(d);
}
int output_concat_dim = 0;
int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
const auto in = values[i];

View File

@ -412,6 +412,17 @@ class ConcatOpTest(tf.test.TestCase):
self.assertEqual(n + 3, after - before)
print("graph = ", [x.name for x in g.get_operations()])
def testConcatLargeTensors(self):
    """Concatenating a tensor with more than 2**31 elements must not crash.

    CPU-only test, because it fails on GPUs with <= 4GB memory.
    """
    with tf.device("/cpu:0"):
        ones_big = tf.ones([2**31 + 6], dtype=tf.int8)
        zeros_small = tf.zeros([1024], dtype=tf.int8)
        onezeros = tf.concat(0, [ones_big, zeros_small])
    with self.test_session(use_gpu=False):
        # TODO(dga): Add more depth to this test to validate correctness,
        # not just non-crashingness, once other large tensor fixes have gone in.
        _ = onezeros.eval()
class ConcatOffsetTest(tf.test.TestCase):