diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index b9115a1ee00..dda25e40789 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -93,7 +93,7 @@ Status DynamicStitchShapeFunction(InferenceContext* c) { TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions)); bool all_indices_constant = true; - int32 max_index = 0; + int32 max_index = -1; ShapeHandle extra_shape = c->UnknownShape(); for (int i = 0; i < num_partitions; ++i) { const Tensor* indices_t = c->input_tensor(i); diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index 4d57c1b264a..50d11a62793 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -131,6 +131,20 @@ class DynamicStitchTestBase(object): # Dimension 0 is max(flatten(indices))+1. self.assertEqual([8, 2], stitched_t.get_shape().as_list()) + def testAllZeroSizeTensor(self): + indices = [ + array_ops.zeros([0], dtype=dtypes.int32), + array_ops.zeros([0], dtype=dtypes.int32) + ] + data = [ + array_ops.zeros([0, 2], dtype=dtypes.int32), + array_ops.zeros([0, 2], dtype=dtypes.int32) + ] + stitched_t = self.stitch_op(indices, data) + stitched_val = self.evaluate(stitched_t) + self.assertAllEqual(np.zeros((0, 2)), stitched_val) + self.assertEqual([0, 2], stitched_t.get_shape().as_list()) + @test_util.run_deprecated_v1 def testHigherRank(self): indices = [