From 1b604e5c34ded796e030ee79aa3e0d81c6351ad5 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 6 Feb 2020 17:47:12 -0800 Subject: [PATCH] Raise error early if Tensor is not iterator. Now iter(tensor) would raise an error if tensor is not iterable. Earlier, iter(tensor) never raised an error but next(iter(tensor)) could. PiperOrigin-RevId: 293717639 Change-Id: Ic01280ff174478b4dbd2954163f9e76d8ed00d02 --- tensorflow/python/eager/tensor_test.py | 6 ++++++ tensorflow/python/framework/ops.py | 24 ++++++++++++++++++++++-- tensorflow/python/framework/ops_test.py | 2 +- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index fd961671b52..342bd37eea5 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -272,6 +272,12 @@ class TFETensorTest(test_util.TensorFlowTestCase): for list_element, tensor_element in zip(l, t): self.assertAllEqual(list_element, tensor_element.numpy()) + def testIterateOverScalarTensorRaises(self): + t = _create_tensor(1) + with self.assertRaisesRegexp(TypeError, + "Cannot iterate over a scalar tensor"): + iter(t) + @test_util.run_gpu_only def testStringTensorOnGPU(self): with ops.device("/device:GPU:0"): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index e92e56671b7..053d34c8da6 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -569,8 +569,7 @@ class Tensor(_TensorLike): if shape[0] is None: raise TypeError( "Cannot iterate over a tensor with unknown first dimension.") - for i in xrange(shape[0]): - yield self[i] + return _TensorIterator(self, shape[0]) def _shape_as_list(self): if self.shape.ndims is not None: @@ -6705,3 +6704,24 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs): assert i == len(inputs) return grouped_inputs + + +class _TensorIterator(object): + """Iterates over the leading dim of a Tensor. Performs no error checks.""" + + def __init__(self, tensor, dim0): + self._tensor = tensor + self._index = 0 + self._limit = dim0 + + def __iter__(self): + return self + + def __next__(self): + if self._index == self._limit: + raise StopIteration + result = self._tensor[self._index] + self._index += 1 + return result + + next = __next__ # python2.x compatibility. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 1a7410ffa76..dbd5abf7c5c 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -107,7 +107,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) t = op.outputs[0] with self.assertRaisesRegexp(TypeError, "Cannot iterate"): - next(iter(t)) + iter(t) def testIterableGraph(self): if context.executing_eagerly():