Merge pull request #39241 from yongtang:39229-map_fn-empty-list
PiperOrigin-RevId: 318950663 Change-Id: Ia75f5bf028a94abbb7ebf2c7d3193309bdef9029
This commit is contained in:
commit
d80b6adbd2
@ -237,6 +237,12 @@ class MapFnTest(test.TestCase):
|
||||
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
|
||||
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testMapEmptyList(self):
|
||||
x = []
|
||||
with self.assertRaisesRegexp(ValueError, r"elems must be a Tensor or"):
|
||||
_ = map_fn.map_fn(lambda e: e, x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -267,7 +267,7 @@ def map_fn(fn,
|
||||
elems: A tensor or (possibly nested) sequence of tensors, each of which will
|
||||
be unstacked along their first dimension. `fn` will be applied to the
|
||||
nested sequence of the resulting slices. `elems` may include ragged and
|
||||
sparse tensors.
|
||||
sparse tensors. `elems` must consist of at least one tensor.
|
||||
dtype: Deprecated: Equivalent to `fn_output_signature`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run in
|
||||
parallel. When graph building, the default value is 10. While executing
|
||||
@ -296,7 +296,7 @@ def map_fn(fn,
|
||||
TypeError: if `fn` is not callable or the structure of the output of
|
||||
`fn` and `fn_output_signature` do not match.
|
||||
ValueError: if the lengths of the output of `fn` and `fn_output_signature`
|
||||
do not match.
|
||||
do not match, or if the `elems` does not contain any tensor.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -375,6 +375,13 @@ def map_fn(fn,
|
||||
|
||||
# Flatten the input tensors, and get the TypeSpec for each one.
|
||||
elems_flat = nest.flatten(elems)
|
||||
|
||||
# Check in case this is an empty list
|
||||
if len(elems_flat) == 0:
|
||||
raise ValueError(
|
||||
"elems must be a Tensor or (possibly nested) sequence of Tensors. "
|
||||
"Got {}, which does not contain any Tensors.".format(elems))
|
||||
|
||||
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
|
||||
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user