From d4bcf76529a8752f70aef4e121c2fbd21406f8a3 Mon Sep 17 00:00:00 2001 From: RJ Skerry-Ryan Date: Tue, 2 Jun 2020 23:35:42 -0700 Subject: [PATCH] Support Iterables in ops.convert_n_to_tensor_or_indexed_slices and math_ops.add_n. Fixes a tf.add_n papercut where it cannot be used with common iterable types beyond list/tuple. It only iterates the collection once, so it need not be any more specific than requiring an iterable. PiperOrigin-RevId: 314477113 Change-Id: I9bc1d2d3b2296e72f085d23481738df616b0ec6c --- tensorflow/python/framework/indexed_slices.py | 8 ++++---- tensorflow/python/ops/math_ops.py | 8 +++++--- tensorflow/python/ops/math_ops_test.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/framework/indexed_slices.py b/tensorflow/python/framework/indexed_slices.py index f85d0e77481..a2746d22650 100644 --- a/tensorflow/python/framework/indexed_slices.py +++ b/tensorflow/python/framework/indexed_slices.py @@ -327,8 +327,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values, unmodified. Args: - values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that - can be consumed by `convert_to_tensor()`. + values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects + that can be consumed by `convert_to_tensor()`. dtype: (Optional.) The required `DType` of the returned `Tensor` or `IndexedSlices`. name: (Optional.) A name prefix to used when a new `Tensor` is created, in @@ -344,8 +344,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values, RuntimeError: If a registered conversion function returns an invalid value. """ - if not isinstance(values, collections.Sequence): - raise TypeError("values must be a sequence.") + if not isinstance(values, collections.Iterable): + raise TypeError("values must be iterable.") ret = [] for i, value in enumerate(values): if value is None: diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 95ace7f757b..0719eb7d164 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -70,6 +70,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + import numpy as np import six from six.moves import builtins @@ -3421,12 +3423,12 @@ def add_n(inputs, name=None): ValueError: If `inputs` don't all have same shape and dtype or the shape cannot be inferred. """ - if not inputs or not isinstance(inputs, (list, tuple)): - raise ValueError("inputs must be a list of at least one " + if not inputs or not isinstance(inputs, collections.Iterable): + raise ValueError("inputs must be an iterable of at least one " "Tensor/IndexedSlices with the same dtype and shape") inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs): - raise ValueError("inputs must be a list of at least one " + raise ValueError("inputs must be an iterable of at least one " "Tensor/IndexedSlices with the same dtype and shape") if len(inputs) == 1: diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 66bb541049a..940966741dc 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -423,6 +423,16 @@ class AddNTest(test_util.TensorFlowTestCase): self.assertAllEqual(slc_as_dense, math_ops.add_n([slc])) self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc])) + def test_iterable(self): + """Test that add_n supports iterables (e.g. generators and dict values).""" + def fn(): + yield 1 + yield 2 + values_dict = {"a": 1, "b": 2} + with test_util.use_gpu(): + self.assertAllEqual(3, math_ops.add_n(fn())) + self.assertAllEqual(3, math_ops.add_n(values_dict.values())) + @test_util.run_all_in_graph_and_eager_modes class DivAndModTest(test_util.TensorFlowTestCase):