Return ValueError in case of empty list input for tf.map_fn

This PR tries to address the issue raised in 39229 where
empty lists input was not checked and throw out a non-obvious error:
```python
>>> import numpy as np
>>> import tensorflow as tf
>>> fn = lambda x: x
>>> tf.map_fn(fn, [])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Python/3.7/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/Library/Python/3.7/site-packages/tensorflow/python/ops/map_fn.py", line 425, in map_fn_v2
    name=name)
  File "/Library/Python/3.7/site-packages/tensorflow/python/ops/map_fn.py", line 213, in map_fn
    static_shape = elems_flat[0].shape
IndexError: list index out of range
>>>
```

In case of empty list the behavior is undefined as we even don't know the output dtype.

This PR update to perform a check and thrown out
`ValueError("elems must not be empty")` to help clarify.

This PR fixes 39229.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-05-06 22:29:23 +00:00
parent ae76544efc
commit 0ee6b3a69d
2 changed files with 13 additions and 0 deletions

View File

@ -217,6 +217,12 @@ class MapFnTest(test.TestCase):
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
@test_util.run_in_graph_and_eager_modes
def testMapEmptyList(self):
x = []
with self.assertRaisesRegexp(
ValueError, r"elems must be a Tensor or"):
_ = map_fn.map_fn(lambda e: e, x)
if __name__ == "__main__":
test.main()

View File

@ -375,6 +375,13 @@ def map_fn(fn,
# Flatten the input tensors, and get the TypeSpec for each one.
elems_flat = nest.flatten(elems)
# Check in case this is an empty list
if len(elems_flat) == 0:
raise ValueError(
"elems must be a Tensor or (possibly nested) sequence of Tensors. "
"Got {}, which does not contain any Tensors.".format(elems))
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)