Return ValueError in case of empty list input for tf.map_fn
This PR tries to address the issue raised in 39229 where empty lists input was not checked and throw out a non-obvious error: ```python >>> import numpy as np >>> import tensorflow as tf >>> fn = lambda x: x >>> tf.map_fn(fn, []) Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Library/Python/3.7/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func return func(*args, **kwargs) File "/Library/Python/3.7/site-packages/tensorflow/python/ops/map_fn.py", line 425, in map_fn_v2 name=name) File "/Library/Python/3.7/site-packages/tensorflow/python/ops/map_fn.py", line 213, in map_fn static_shape = elems_flat[0].shape IndexError: list index out of range >>> ``` In case of empty list the behavior is undefined as we even don't know the output dtype. This PR update to perform a check and thrown out `ValueError("elems must not be empty")` to help clarify. This PR fixes 39229. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
ae76544efc
commit
0ee6b3a69d
@ -217,6 +217,12 @@ class MapFnTest(test.TestCase):
|
||||
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
|
||||
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testMapEmptyList(self):
|
||||
x = []
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r"elems must be a Tensor or"):
|
||||
_ = map_fn.map_fn(lambda e: e, x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -375,6 +375,13 @@ def map_fn(fn,
|
||||
|
||||
# Flatten the input tensors, and get the TypeSpec for each one.
|
||||
elems_flat = nest.flatten(elems)
|
||||
|
||||
# Check in case this is an empty list
|
||||
if len(elems_flat) == 0:
|
||||
raise ValueError(
|
||||
"elems must be a Tensor or (possibly nested) sequence of Tensors. "
|
||||
"Got {}, which does not contain any Tensors.".format(elems))
|
||||
|
||||
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
|
||||
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user