Support Iterables in ops.convert_n_to_tensor_or_indexed_slices and math_ops.add_n.
Fixes a tf.add_n papercut where it cannot be used with common iterable types beyond list/tuple. It only iterates the collection once, so it need not be any more specific than requiring an iterable. PiperOrigin-RevId: 314477113 Change-Id: I9bc1d2d3b2296e72f085d23481738df616b0ec6c
This commit is contained in:
parent
bd20260350
commit
d4bcf76529
@ -327,8 +327,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
|
|||||||
unmodified.
|
unmodified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
|
values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
|
||||||
can be consumed by `convert_to_tensor()`.
|
that can be consumed by `convert_to_tensor()`.
|
||||||
dtype: (Optional.) The required `DType` of the returned `Tensor` or
|
dtype: (Optional.) The required `DType` of the returned `Tensor` or
|
||||||
`IndexedSlices`.
|
`IndexedSlices`.
|
||||||
name: (Optional.) A name prefix to used when a new `Tensor` is created, in
|
name: (Optional.) A name prefix to used when a new `Tensor` is created, in
|
||||||
@ -344,8 +344,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
|
|||||||
RuntimeError: If a registered conversion function returns an invalid
|
RuntimeError: If a registered conversion function returns an invalid
|
||||||
value.
|
value.
|
||||||
"""
|
"""
|
||||||
if not isinstance(values, collections.Sequence):
|
if not isinstance(values, collections.Iterable):
|
||||||
raise TypeError("values must be a sequence.")
|
raise TypeError("values must be iterable.")
|
||||||
ret = []
|
ret = []
|
||||||
for i, value in enumerate(values):
|
for i, value in enumerate(values):
|
||||||
if value is None:
|
if value is None:
|
||||||
|
@ -70,6 +70,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
from six.moves import builtins
|
from six.moves import builtins
|
||||||
@ -3421,12 +3423,12 @@ def add_n(inputs, name=None):
|
|||||||
ValueError: If `inputs` don't all have same shape and dtype or the shape
|
ValueError: If `inputs` don't all have same shape and dtype or the shape
|
||||||
cannot be inferred.
|
cannot be inferred.
|
||||||
"""
|
"""
|
||||||
if not inputs or not isinstance(inputs, (list, tuple)):
|
if not inputs or not isinstance(inputs, collections.Iterable):
|
||||||
raise ValueError("inputs must be a list of at least one "
|
raise ValueError("inputs must be an iterable of at least one "
|
||||||
"Tensor/IndexedSlices with the same dtype and shape")
|
"Tensor/IndexedSlices with the same dtype and shape")
|
||||||
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
|
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
|
||||||
if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
|
if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
|
||||||
raise ValueError("inputs must be a list of at least one "
|
raise ValueError("inputs must be an iterable of at least one "
|
||||||
"Tensor/IndexedSlices with the same dtype and shape")
|
"Tensor/IndexedSlices with the same dtype and shape")
|
||||||
|
|
||||||
if len(inputs) == 1:
|
if len(inputs) == 1:
|
||||||
|
@ -423,6 +423,16 @@ class AddNTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(slc_as_dense, math_ops.add_n([slc]))
|
self.assertAllEqual(slc_as_dense, math_ops.add_n([slc]))
|
||||||
self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc]))
|
self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc]))
|
||||||
|
|
||||||
|
def test_iterable(self):
|
||||||
|
"""Test that add_n supports iterables (e.g. generators and dict values)."""
|
||||||
|
def fn():
|
||||||
|
yield 1
|
||||||
|
yield 2
|
||||||
|
values_dict = {"a": 1, "b": 2}
|
||||||
|
with test_util.use_gpu():
|
||||||
|
self.assertAllEqual(3, math_ops.add_n(fn()))
|
||||||
|
self.assertAllEqual(3, math_ops.add_n(values_dict.values()))
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class DivAndModTest(test_util.TensorFlowTestCase):
|
class DivAndModTest(test_util.TensorFlowTestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user