Raise error early if Tensor is not iterator.

Now iter(tensor) would raise an error if tensor is not iterable. Earlier, iter(tensor) never raised an error but next(iter(tensor)) could.

PiperOrigin-RevId: 293717639
Change-Id: Ic01280ff174478b4dbd2954163f9e76d8ed00d02
This commit is contained in:
Saurabh Saxena 2020-02-06 17:47:12 -08:00 committed by TensorFlower Gardener
parent 6f34a6c3e9
commit 1b604e5c34
3 changed files with 29 additions and 3 deletions

View File

@ -272,6 +272,12 @@ class TFETensorTest(test_util.TensorFlowTestCase):
for list_element, tensor_element in zip(l, t):
self.assertAllEqual(list_element, tensor_element.numpy())
def testIterateOverScalarTensorRaises(self):
t = _create_tensor(1)
with self.assertRaisesRegexp(TypeError,
"Cannot iterate over a scalar tensor"):
iter(t)
@test_util.run_gpu_only
def testStringTensorOnGPU(self):
with ops.device("/device:GPU:0"):

View File

@ -569,8 +569,7 @@ class Tensor(_TensorLike):
if shape[0] is None:
raise TypeError(
"Cannot iterate over a tensor with unknown first dimension.")
for i in xrange(shape[0]):
yield self[i]
return _TensorIterator(self, shape[0])
def _shape_as_list(self):
if self.shape.ndims is not None:
@ -6705,3 +6704,24 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
assert i == len(inputs)
return grouped_inputs
class _TensorIterator(object):
"""Iterates over the leading dim of a Tensor. Performs no error checks."""
def __init__(self, tensor, dim0):
self._tensor = tensor
self._index = 0
self._limit = dim0
def __iter__(self):
return self
def __next__(self):
if self._index == self._limit:
raise StopIteration
result = self._tensor[self._index]
self._index += 1
return result
next = __next__ # python2.x compatibility.

View File

@ -107,7 +107,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
t = op.outputs[0]
with self.assertRaisesRegexp(TypeError, "Cannot iterate"):
next(iter(t))
iter(t)
def testIterableGraph(self):
if context.executing_eagerly():