Support Iterables in ops.convert_n_to_tensor_or_indexed_slices and math_ops.add_n.
Fixes a tf.add_n papercut where it cannot be used with common iterable types beyond list/tuple. It only iterates the collection once, so it need not be any more specific than requiring an iterable. PiperOrigin-RevId: 314477113 Change-Id: I9bc1d2d3b2296e72f085d23481738df616b0ec6c
This commit is contained in:
parent
bd20260350
commit
d4bcf76529
@ -327,8 +327,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
|
||||
unmodified.
|
||||
|
||||
Args:
|
||||
values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
|
||||
can be consumed by `convert_to_tensor()`.
|
||||
values: An iterable of `None`, `IndexedSlices`, `SparseTensor`, or objects
|
||||
that can be consumed by `convert_to_tensor()`.
|
||||
dtype: (Optional.) The required `DType` of the returned `Tensor` or
|
||||
`IndexedSlices`.
|
||||
name: (Optional.) A name prefix to used when a new `Tensor` is created, in
|
||||
@ -344,8 +344,8 @@ def internal_convert_n_to_tensor_or_indexed_slices(values,
|
||||
RuntimeError: If a registered conversion function returns an invalid
|
||||
value.
|
||||
"""
|
||||
if not isinstance(values, collections.Sequence):
|
||||
raise TypeError("values must be a sequence.")
|
||||
if not isinstance(values, collections.Iterable):
|
||||
raise TypeError("values must be iterable.")
|
||||
ret = []
|
||||
for i, value in enumerate(values):
|
||||
if value is None:
|
||||
|
@ -70,6 +70,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import builtins
|
||||
@ -3421,12 +3423,12 @@ def add_n(inputs, name=None):
|
||||
ValueError: If `inputs` don't all have same shape and dtype or the shape
|
||||
cannot be inferred.
|
||||
"""
|
||||
if not inputs or not isinstance(inputs, (list, tuple)):
|
||||
raise ValueError("inputs must be a list of at least one "
|
||||
if not inputs or not isinstance(inputs, collections.Iterable):
|
||||
raise ValueError("inputs must be an iterable of at least one "
|
||||
"Tensor/IndexedSlices with the same dtype and shape")
|
||||
inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs)
|
||||
if not all(isinstance(x, (ops.Tensor, ops.IndexedSlices)) for x in inputs):
|
||||
raise ValueError("inputs must be a list of at least one "
|
||||
raise ValueError("inputs must be an iterable of at least one "
|
||||
"Tensor/IndexedSlices with the same dtype and shape")
|
||||
|
||||
if len(inputs) == 1:
|
||||
|
@ -423,6 +423,16 @@ class AddNTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(slc_as_dense, math_ops.add_n([slc]))
|
||||
self.assertAllEqual(2 * slc_as_dense, math_ops.add_n([slc, slc]))
|
||||
|
||||
def test_iterable(self):
|
||||
"""Test that add_n supports iterables (e.g. generators and dict values)."""
|
||||
def fn():
|
||||
yield 1
|
||||
yield 2
|
||||
values_dict = {"a": 1, "b": 2}
|
||||
with test_util.use_gpu():
|
||||
self.assertAllEqual(3, math_ops.add_n(fn()))
|
||||
self.assertAllEqual(3, math_ops.add_n(values_dict.values()))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DivAndModTest(test_util.TensorFlowTestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user