[tf.data] Prevent mutation of numpy arrays returned through as_numpy_iterator.

PiperOrigin-RevId: 355551968
Change-Id: I3ce0c4f014ab47088d44d9f75013b3b63b3bce76
This commit is contained in:
Andrew Audibert 2021-02-03 21:59:54 -08:00 committed by TensorFlower Gardener
parent d1012c9ff7
commit a9e9b8aae2
2 changed files with 11 additions and 1 deletions

View File

@ -38,6 +38,14 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = dataset_ops.Dataset.range(3)
self.assertEqual([0, 1, 2], list(ds.as_numpy_iterator()))
@combinations.generate(test_base.eager_only_combinations())
def testImmutable(self):
ds = dataset_ops.Dataset.from_tensors([1, 2, 3])
arr = next(ds.as_numpy_iterator())
with self.assertRaisesRegex(ValueError,
'assignment destination is read-only'):
arr[0] = 0
@combinations.generate(test_base.eager_only_combinations())
def testNestedStructure(self):
point = collections.namedtuple('Point', ['x', 'y'])

View File

@ -4004,7 +4004,9 @@ class _NumpyIterator(object):
def to_numpy(x):
numpy = x._numpy() # pylint: disable=protected-access
if isinstance(numpy, np.ndarray):
return np.asarray(memoryview(numpy))
# `numpy` shares the same underlying buffer as the `x` Tensor.
# Tensors are expected to be immutable, so we disable writes.
numpy.setflags(write=False)
return numpy
return nest.map_structure(to_numpy, next(self._iterator))