[tf.data] Prevent mutation of numpy arrays returned through as_numpy_iterator.
PiperOrigin-RevId: 355551968 Change-Id: I3ce0c4f014ab47088d44d9f75013b3b63b3bce76
This commit is contained in:
parent
d1012c9ff7
commit
a9e9b8aae2
@ -38,6 +38,14 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
ds = dataset_ops.Dataset.range(3)
|
||||
self.assertEqual([0, 1, 2], list(ds.as_numpy_iterator()))
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testImmutable(self):
|
||||
ds = dataset_ops.Dataset.from_tensors([1, 2, 3])
|
||||
arr = next(ds.as_numpy_iterator())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'assignment destination is read-only'):
|
||||
arr[0] = 0
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testNestedStructure(self):
|
||||
point = collections.namedtuple('Point', ['x', 'y'])
|
||||
|
@ -4004,7 +4004,9 @@ class _NumpyIterator(object):
|
||||
def to_numpy(x):
|
||||
numpy = x._numpy() # pylint: disable=protected-access
|
||||
if isinstance(numpy, np.ndarray):
|
||||
return np.asarray(memoryview(numpy))
|
||||
# `numpy` shares the same underlying buffer as the `x` Tensor.
|
||||
# Tensors are expected to be immutable, so we disable writes.
|
||||
numpy.setflags(write=False)
|
||||
return numpy
|
||||
|
||||
return nest.map_structure(to_numpy, next(self._iterator))
|
||||
|
Loading…
x
Reference in New Issue
Block a user