Fix static shape computation of dynamic_stitch when all tensors are empty.
PiperOrigin-RevId: 274209565
This commit is contained in:
parent
8c21158cd4
commit
5411ea4ed4
@ -93,7 +93,7 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
|
|||||||
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
|
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
|
||||||
|
|
||||||
bool all_indices_constant = true;
|
bool all_indices_constant = true;
|
||||||
int32 max_index = 0;
|
int32 max_index = -1;
|
||||||
ShapeHandle extra_shape = c->UnknownShape();
|
ShapeHandle extra_shape = c->UnknownShape();
|
||||||
for (int i = 0; i < num_partitions; ++i) {
|
for (int i = 0; i < num_partitions; ++i) {
|
||||||
const Tensor* indices_t = c->input_tensor(i);
|
const Tensor* indices_t = c->input_tensor(i);
|
||||||
|
@ -131,6 +131,20 @@ class DynamicStitchTestBase(object):
|
|||||||
# Dimension 0 is max(flatten(indices))+1.
|
# Dimension 0 is max(flatten(indices))+1.
|
||||||
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
|
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
|
||||||
|
|
||||||
|
def testAllZeroSizeTensor(self):
|
||||||
|
indices = [
|
||||||
|
array_ops.zeros([0], dtype=dtypes.int32),
|
||||||
|
array_ops.zeros([0], dtype=dtypes.int32)
|
||||||
|
]
|
||||||
|
data = [
|
||||||
|
array_ops.zeros([0, 2], dtype=dtypes.int32),
|
||||||
|
array_ops.zeros([0, 2], dtype=dtypes.int32)
|
||||||
|
]
|
||||||
|
stitched_t = self.stitch_op(indices, data)
|
||||||
|
stitched_val = self.evaluate(stitched_t)
|
||||||
|
self.assertAllEqual(np.zeros((0, 2)), stitched_val)
|
||||||
|
self.assertEqual([0, 2], stitched_t.get_shape().as_list())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testHigherRank(self):
|
def testHigherRank(self):
|
||||||
indices = [
|
indices = [
|
||||||
|
Loading…
Reference in New Issue
Block a user