Raise error early if Tensor is not iterator.
Now iter(tensor) would raise an error if tensor is not iterable. Earlier, iter(tensor) never raised an error but next(iter(tensor)) could. PiperOrigin-RevId: 293717639 Change-Id: Ic01280ff174478b4dbd2954163f9e76d8ed00d02
This commit is contained in:
parent
6f34a6c3e9
commit
1b604e5c34
@ -272,6 +272,12 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
|||||||
for list_element, tensor_element in zip(l, t):
|
for list_element, tensor_element in zip(l, t):
|
||||||
self.assertAllEqual(list_element, tensor_element.numpy())
|
self.assertAllEqual(list_element, tensor_element.numpy())
|
||||||
|
|
||||||
|
def testIterateOverScalarTensorRaises(self):
|
||||||
|
t = _create_tensor(1)
|
||||||
|
with self.assertRaisesRegexp(TypeError,
|
||||||
|
"Cannot iterate over a scalar tensor"):
|
||||||
|
iter(t)
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testStringTensorOnGPU(self):
|
def testStringTensorOnGPU(self):
|
||||||
with ops.device("/device:GPU:0"):
|
with ops.device("/device:GPU:0"):
|
||||||
|
@ -569,8 +569,7 @@ class Tensor(_TensorLike):
|
|||||||
if shape[0] is None:
|
if shape[0] is None:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Cannot iterate over a tensor with unknown first dimension.")
|
"Cannot iterate over a tensor with unknown first dimension.")
|
||||||
for i in xrange(shape[0]):
|
return _TensorIterator(self, shape[0])
|
||||||
yield self[i]
|
|
||||||
|
|
||||||
def _shape_as_list(self):
|
def _shape_as_list(self):
|
||||||
if self.shape.ndims is not None:
|
if self.shape.ndims is not None:
|
||||||
@ -6705,3 +6704,24 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
|
|||||||
|
|
||||||
assert i == len(inputs)
|
assert i == len(inputs)
|
||||||
return grouped_inputs
|
return grouped_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class _TensorIterator(object):
|
||||||
|
"""Iterates over the leading dim of a Tensor. Performs no error checks."""
|
||||||
|
|
||||||
|
def __init__(self, tensor, dim0):
|
||||||
|
self._tensor = tensor
|
||||||
|
self._index = 0
|
||||||
|
self._limit = dim0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self._index == self._limit:
|
||||||
|
raise StopIteration
|
||||||
|
result = self._tensor[self._index]
|
||||||
|
self._index += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
next = __next__ # python2.x compatibility.
|
||||||
|
@ -107,7 +107,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
|
|||||||
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
|
||||||
t = op.outputs[0]
|
t = op.outputs[0]
|
||||||
with self.assertRaisesRegexp(TypeError, "Cannot iterate"):
|
with self.assertRaisesRegexp(TypeError, "Cannot iterate"):
|
||||||
next(iter(t))
|
iter(t)
|
||||||
|
|
||||||
def testIterableGraph(self):
|
def testIterableGraph(self):
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
|
Loading…
Reference in New Issue
Block a user