Let Concat properly handle concat dim > 2^31 when dealing with Very Large Tensors.
Change: 123258451
This commit is contained in:
parent
6e89233b74
commit
8e9f29598a
@ -76,7 +76,7 @@ class ConcatOp : public OpKernel {
|
||||
for (int d = 0; d < concat_dim; ++d) {
|
||||
inputs_flat_dim0 *= input_shape.dim_size(d);
|
||||
}
|
||||
int output_concat_dim = 0;
|
||||
int64 output_concat_dim = 0;
|
||||
const bool input_is_scalar = IsLegacyScalar(input_shape);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
const auto in = values[i];
|
||||
|
@ -412,6 +412,17 @@ class ConcatOpTest(tf.test.TestCase):
|
||||
self.assertEqual(n + 3, after - before)
|
||||
print("graph = ", [x.name for x in g.get_operations()])
|
||||
|
||||
def testConcatLargeTensors(self):
|
||||
# CPU-only test, because it fails on GPUs with <= 4GB memory.
|
||||
with tf.device("/cpu:0"):
|
||||
a = tf.ones([2**31 + 6], dtype=tf.int8)
|
||||
b = tf.zeros([1024], dtype=tf.int8)
|
||||
onezeros = tf.concat(0, [a, b])
|
||||
with self.test_session(use_gpu=False):
|
||||
# TODO(dga): Add more depth to this test to validate correctness,
|
||||
# not just non-crashingness, once other large tensor fixes have gone in.
|
||||
_ = onezeros.eval()
|
||||
|
||||
|
||||
class ConcatOffsetTest(tf.test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user