Fix static shape computation of dynamic_stitch when all tensors are empty.

PiperOrigin-RevId: 274209565
This commit is contained in:
Adria Puigdomenech 2019-10-11 11:10:39 -07:00 committed by TensorFlower Gardener
parent 8c21158cd4
commit 5411ea4ed4
2 changed files with 15 additions and 1 deletions

View File

@ -93,7 +93,7 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
bool all_indices_constant = true;
int32 max_index = 0;
int32 max_index = -1;
ShapeHandle extra_shape = c->UnknownShape();
for (int i = 0; i < num_partitions; ++i) {
const Tensor* indices_t = c->input_tensor(i);

View File

@ -131,6 +131,20 @@ class DynamicStitchTestBase(object):
# Dimension 0 is max(flatten(indices))+1.
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
def testAllZeroSizeTensor(self):
indices = [
array_ops.zeros([0], dtype=dtypes.int32),
array_ops.zeros([0], dtype=dtypes.int32)
]
data = [
array_ops.zeros([0, 2], dtype=dtypes.int32),
array_ops.zeros([0, 2], dtype=dtypes.int32)
]
stitched_t = self.stitch_op(indices, data)
stitched_val = self.evaluate(stitched_t)
self.assertAllEqual(np.zeros((0, 2)), stitched_val)
self.assertEqual([0, 2], stitched_t.get_shape().as_list())
@test_util.run_deprecated_v1
def testHigherRank(self):
indices = [