Fix static shape computation of dynamic_stitch when all tensors are empty.
PiperOrigin-RevId: 274209565
This commit is contained in:
parent
8c21158cd4
commit
5411ea4ed4
@ -93,7 +93,7 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
|
||||
|
||||
bool all_indices_constant = true;
|
||||
int32 max_index = 0;
|
||||
int32 max_index = -1;
|
||||
ShapeHandle extra_shape = c->UnknownShape();
|
||||
for (int i = 0; i < num_partitions; ++i) {
|
||||
const Tensor* indices_t = c->input_tensor(i);
|
||||
|
@ -131,6 +131,20 @@ class DynamicStitchTestBase(object):
|
||||
# Dimension 0 is max(flatten(indices))+1.
|
||||
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
|
||||
|
||||
def testAllZeroSizeTensor(self):
|
||||
indices = [
|
||||
array_ops.zeros([0], dtype=dtypes.int32),
|
||||
array_ops.zeros([0], dtype=dtypes.int32)
|
||||
]
|
||||
data = [
|
||||
array_ops.zeros([0, 2], dtype=dtypes.int32),
|
||||
array_ops.zeros([0, 2], dtype=dtypes.int32)
|
||||
]
|
||||
stitched_t = self.stitch_op(indices, data)
|
||||
stitched_val = self.evaluate(stitched_t)
|
||||
self.assertAllEqual(np.zeros((0, 2)), stitched_val)
|
||||
self.assertEqual([0, 2], stitched_t.get_shape().as_list())
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testHigherRank(self):
|
||||
indices = [
|
||||
|
Loading…
Reference in New Issue
Block a user