Support Iterables in ops.convert_n_to_tensor_or_indexed_slices and math_ops.add_n.

Fixes a tf.add_n papercut where it cannot be used with common iterable types beyond list/tuple. It only iterates the collection once, so it need not be any more specific than requiring an iterable.

PiperOrigin-RevId: 314477113
Change-Id: I9bc1d2d3b2296e72f085d23481738df616b0ec6c
This commit is contained in:
RJ Skerry-Ryan 2020-06-02 23:35:42 -07:00 committed by TensorFlower Gardener
parent bd20260350
commit d4bcf76529
3 changed files with 19 additions and 7 deletions

View File

@ -327,8 +327,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
unmodified.
Args:
values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
can be consumed by `convert_to_tensor()`.
values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
that can be consumed by `convert_to_tensor()`.
dtype: (Optional.) The required `DType` of the returned `Tensor` or
`IndexedSlices`.
name: (Optional.) A name prefix to used when a new `Tensor` is created, in
@ -344,8 +344,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
RuntimeError: If a registered conversion function returns an invalid
value.
"""
if not isinstance(values, collections.Sequence):
raise TypeError("values must be a sequence.")
if not isinstance(values, collections.Iterable):
raise TypeError("values must be iterable.")
ret = []
for i, value in enumerate(values):
if value is None:

View File

@ -70,6 +70,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import six
from six.moves import builtins
@ -3421,12 +3423,12 @@ def add_n(inputs, name=None):
ValueError: If `inputs` don't all have same shape and dtype or the shape
cannot be inferred.
"""
if not inputs or not isinstance(inputs, (list, tuple)):
raise ValueError("inputs must be a list of at least one "
if not inputs or not isinstance(inputs, collections.Iterable):
raise ValueError("inputs must be an iterable of at least one "
"Tensor/IndexedSlices with the same dtype and shape")
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
raise ValueError("inputs must be a list of at least one "
raise ValueError("inputs must be an iterable of at least one "
"Tensor/IndexedSlices with the same dtype and shape")
if len(inputs) == 1:

View File

@ -423,6 +423,16 @@ class AddNTest(test_util.TensorFlowTestCase):
self.assertAllEqual(slc_as_dense, math_ops.add_n([slc]))
self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc]))
def test_iterable(self):
"""Test that add_n supports iterables (e.g. generators and dict values)."""
def fn():
yield 1
yield 2
values_dict = {"a": 1, "b": 2}
with test_util.use_gpu():
self.assertAllEqual(3, math_ops.add_n(fn()))
self.assertAllEqual(3, math_ops.add_n(values_dict.values()))
@test_util.run_all_in_graph_and_eager_modes
class DivAndModTest(test_util.TensorFlowTestCase):