Merge pull request #39241 from yongtang:39229-map_fn-empty-list

PiperOrigin-RevId: 318950663
Change-Id: Ia75f5bf028a94abbb7ebf2c7d3193309bdef9029
This commit is contained in:
TensorFlower Gardener 2020-06-29 21:05:52 -07:00
commit d80b6adbd2
2 changed files with 15 additions and 2 deletions

View File

@ -237,6 +237,12 @@ class MapFnTest(test.TestCase):
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
@test_util.run_in_graph_and_eager_modes
def testMapEmptyList(self):
x = []
with self.assertRaisesRegexp(ValueError, r"elems must be a Tensor or"):
_ = map_fn.map_fn(lambda e: e, x)
if __name__ == "__main__":
test.main()

View File

@ -267,7 +267,7 @@ def map_fn(fn,
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unstacked along their first dimension. `fn` will be applied to the
nested sequence of the resulting slices. `elems` may include ragged and
sparse tensors.
sparse tensors. `elems` must consist of at least one tensor.
dtype: Deprecated: Equivalent to `fn_output_signature`.
parallel_iterations: (optional) The number of iterations allowed to run in
parallel. When graph building, the default value is 10. While executing
@ -296,7 +296,7 @@ def map_fn(fn,
TypeError: if `fn` is not callable or the structure of the output of
`fn` and `fn_output_signature` do not match.
ValueError: if the lengths of the output of `fn` and `fn_output_signature`
do not match.
do not match, or if the `elems` does not contain any tensor.
Examples:
@ -375,6 +375,13 @@ def map_fn(fn,
# Flatten the input tensors, and get the TypeSpec for each one.
elems_flat = nest.flatten(elems)
# Check in case this is an empty list
if len(elems_flat) == 0:
raise ValueError(
"elems must be a Tensor or (possibly nested) sequence of Tensors. "
"Got {}, which does not contain any Tensors.".format(elems))
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)