Merge pull request #39241 from yongtang:39229-map_fn-empty-list
PiperOrigin-RevId: 318950663 Change-Id: Ia75f5bf028a94abbb7ebf2c7d3193309bdef9029
This commit is contained in:
commit
d80b6adbd2
@ -237,6 +237,12 @@ class MapFnTest(test.TestCase):
|
|||||||
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
|
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
|
||||||
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
|
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testMapEmptyList(self):
|
||||||
|
x = []
|
||||||
|
with self.assertRaisesRegexp(ValueError, r"elems must be a Tensor or"):
|
||||||
|
_ = map_fn.map_fn(lambda e: e, x)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -267,7 +267,7 @@ def map_fn(fn,
|
|||||||
elems: A tensor or (possibly nested) sequence of tensors, each of which will
|
elems: A tensor or (possibly nested) sequence of tensors, each of which will
|
||||||
be unstacked along their first dimension. `fn` will be applied to the
|
be unstacked along their first dimension. `fn` will be applied to the
|
||||||
nested sequence of the resulting slices. `elems` may include ragged and
|
nested sequence of the resulting slices. `elems` may include ragged and
|
||||||
sparse tensors.
|
sparse tensors. `elems` must consist of at least one tensor.
|
||||||
dtype: Deprecated: Equivalent to `fn_output_signature`.
|
dtype: Deprecated: Equivalent to `fn_output_signature`.
|
||||||
parallel_iterations: (optional) The number of iterations allowed to run in
|
parallel_iterations: (optional) The number of iterations allowed to run in
|
||||||
parallel. When graph building, the default value is 10. While executing
|
parallel. When graph building, the default value is 10. While executing
|
||||||
@ -296,7 +296,7 @@ def map_fn(fn,
|
|||||||
TypeError: if `fn` is not callable or the structure of the output of
|
TypeError: if `fn` is not callable or the structure of the output of
|
||||||
`fn` and `fn_output_signature` do not match.
|
`fn` and `fn_output_signature` do not match.
|
||||||
ValueError: if the lengths of the output of `fn` and `fn_output_signature`
|
ValueError: if the lengths of the output of `fn` and `fn_output_signature`
|
||||||
do not match.
|
do not match, or if the `elems` does not contain any tensor.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
@ -375,6 +375,13 @@ def map_fn(fn,
|
|||||||
|
|
||||||
# Flatten the input tensors, and get the TypeSpec for each one.
|
# Flatten the input tensors, and get the TypeSpec for each one.
|
||||||
elems_flat = nest.flatten(elems)
|
elems_flat = nest.flatten(elems)
|
||||||
|
|
||||||
|
# Check in case this is an empty list
|
||||||
|
if len(elems_flat) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"elems must be a Tensor or (possibly nested) sequence of Tensors. "
|
||||||
|
"Got {}, which does not contain any Tensors.".format(elems))
|
||||||
|
|
||||||
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
|
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
|
||||||
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
|
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user