Disable __setitem__ on tf numpy's ndarray for experimental release
PiperOrigin-RevId: 317604521
Change-Id: I19c6a78125fcd29109e45a66099e7325f7136fdf
parent 5229c77d94
commit 072cf7ee4b
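
For context, a minimal sketch of the user-visible effect of this change (the import path is the internal experimental module touched by this diff; the error text is standard Python behavior for a type that defines no __setitem__):

    from tensorflow.python.ops.numpy_ops import np_array_ops

    a = np_array_ops.array([1, 2, 3])
    try:
      a[0] = 5  # Previously dispatched to the _setitem helper removed below.
    except TypeError as e:
      # With __setitem__ gone, Python itself rejects item assignment, e.g.
      # "'ndarray' object does not support item assignment".
      print(e)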
@@ -65,16 +65,6 @@ cuda_py_test(
     ],
 )

-cuda_py_test(
-    name = "np_backprop_test",
-    srcs = ["np_backprop_test.py"],
-    deps = [
-        ":numpy",
-        "//tensorflow/python:platform",
-        "@absl_py//absl/testing:parameterized",
-    ],
-)
-
 cuda_py_test(
     name = "np_logic_test",
     srcs = ["np_logic_test.py"],
@@ -852,66 +852,10 @@ def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
   return np_utils.tensor_to_ndarray(a)


-def _setitem(arr, index, value):
-  """Sets the `value` at `index` in the array `arr`.
-
-  This works by replacing the slice at `index` in the tensor with `value`.
-  Since tensors are immutable, this builds a new tensor using the `tf.concat`
-  op. Currently, only 0-d and 1-d indices are supported.
-
-  Note that this may break gradients e.g.
-
-  a = tf_np.array([1, 2, 3])
-  old_a_t = a.data
-
-  with tf.GradientTape(persistent=True) as g:
-    g.watch(a.data)
-    b = a * 2
-    a[0] = 5
-  g.gradient(b.data, [a.data])  # [None]
-  g.gradient(b.data, [old_a_t])  # [[2., 2., 2.]]
-
-  Here `d_b / d_a` is `[None]` since a.data no longer points to the same
-  tensor.
-
-  Args:
-    arr: array_like.
-    index: scalar or 1-d integer array.
-    value: value to set at index.
-
-  Returns:
-    ndarray
-
-  Raises:
-    ValueError: if `index` is not a scalar or 1-d array.
-  """
-  # TODO(srbs): Figure out a solution to the gradient problem.
-  arr = asarray(arr)
-  index = asarray(index)
-  if index.ndim == 0:
-    index = ravel(index)
-  elif index.ndim > 1:
-    raise ValueError('index must be a scalar or a 1-d array.')
-  value = asarray(value, dtype=arr.dtype)
-  if arr.shape[len(index):] != value.shape:
-    value = full(arr.shape[len(index):], value)
-  prefix_t = arr.data[:index.data[0]]
-  postfix_t = arr.data[index.data[0] + 1:]
-  if len(index) == 1:
-    arr._data = array_ops.concat(  # pylint: disable=protected-access
-        [prefix_t, array_ops.expand_dims(value.data, 0), postfix_t], 0)
-  else:
-    subarray = arr[index.data[0]]
-    _setitem(subarray, index[1:], value)
-    arr._data = array_ops.concat(  # pylint: disable=protected-access
-        [prefix_t, array_ops.expand_dims(subarray.data, 0), postfix_t], 0)
-
-
 # TODO(wangpeng): Make a custom `setattr` that also sets docstring for the
 # method.
 setattr(np_arrays.ndarray, 'transpose', transpose)
 setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper)
-setattr(np_arrays.ndarray, '__setitem__', _setitem)


 @np_utils.np_doc('pad')
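The removed docstring describes the concat-based update and its gradient caveat only inline; below is a standalone re-creation using public TF ops (tf.concat, tf.GradientTape; the set_index helper and variable names are illustrative, not part of the removed code):

    import tensorflow as tf

    def set_index(t, i, value):
      # Rebuild the tensor with `value` spliced in at position `i`; tensors
      # are immutable, so the result is a brand-new tensor.
      return tf.concat([t[:i], tf.expand_dims(value, 0), t[i + 1:]], 0)

    old_t = tf.constant([1., 2., 3.])
    with tf.GradientTape(persistent=True) as g:
      g.watch(old_t)
      b = old_t * 2.
      new_t = set_index(old_t, 0, tf.constant(5.))

    print(g.gradient(b, new_t))  # None: b was never computed from new_t.
    print(g.gradient(b, old_t))  # [2. 2. 2.]: the original tensor still connects.

This is exactly the failure mode the docstring warns about: after an in-place-looking assignment, arr.data points at a new tensor that the rest of the traced computation never saw.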
@@ -911,26 +911,6 @@ class ArrayMethodsTest(test.TestCase):
     run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 0, 1])
     run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 1, 0])

-  def testSetItem(self):
-
-    def run_test(arr, index, value):
-      for fn in self.array_transforms:
-        value_arg = fn(value)
-        tf_array = np_array_ops.array(arr)
-        np_array = np.array(arr)
-        tf_array[index] = value_arg
-        # TODO(srbs): "setting an array element with a sequence" is thrown
-        # if we do not wrap value_arg in a numpy array. Investigate how this can
-        # be avoided.
-        np_array[index] = np.array(value_arg)
-        self.match(tf_array, np_array)
-
-    run_test([1, 2, 3], 1, 5)
-    run_test([[1, 2], [3, 4]], 0, [6, 7])
-    run_test([[1, 2], [3, 4]], 1, [6, 7])
-    run_test([[1, 2], [3, 4]], (0, 1), 6)
-    run_test([[1, 2], [3, 4]], 0, 6)  # Value needs to broadcast.
-
   def match_shape(self, actual, expected, msg=None):
     if msg:
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
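The TODO in the removed test refers to a NumPy error that is easy to reproduce in isolation (a minimal sketch; in the test the trigger is the unwrapped value_arg, which NumPy similarly cannot splice in, hence the np.array wrapping):

    import numpy as np

    a = np.zeros(3)
    try:
      a[0] = [6., 7.]  # A sequence where a single element is expected.
    except ValueError as e:
      print(e)  # "setting an array element with a sequence..."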
@@ -1,69 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for backpropgration on tf-numpy functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import backprop
-from tensorflow.python.framework import ops
-from tensorflow.python.ops.numpy_ops import np_array_ops
-# Required for operator overloads
-from tensorflow.python.ops.numpy_ops import np_math_ops  # pylint: disable=unused-import
-from tensorflow.python.platform import test
-
-
-class BackpropTest(test.TestCase):
-
-  def test_setitem(self):
-    # Single integer index.
-    a = np_array_ops.array([1., 2., 3.])
-    b = np_array_ops.array(5.)
-    c = np_array_ops.array(10.)
-
-    tensors = [arr.data for arr in [a, b, c]]
-    with backprop.GradientTape() as g:
-      g.watch(tensors)
-      a[1] = b + c
-      loss = np_array_ops.sum(a)
-
-    gradients = g.gradient(loss.data, tensors)
-    self.assertSequenceEqual(
-        np_array_ops.array(gradients[0]).tolist(), [1., 0., 1.])
-    self.assertEqual(np_array_ops.array(gradients[1]).tolist(), 1.)
-    self.assertEqual(np_array_ops.array(gradients[2]).tolist(), 1.)
-
-    # Tuple index.
-    a = np_array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.],
-                                                   [7., 8.]]])  # 2x2x2 array.
-    b = np_array_ops.array([10., 11.])
-
-    tensors = [arr.data for arr in [a, b]]
-    with backprop.GradientTape() as g:
-      g.watch(tensors)
-      a[(1, 0)] = b
-      loss = np_array_ops.sum(a)
-
-    gradients = g.gradient(loss.data, tensors)
-    self.assertSequenceEqual(
-        np_array_ops.array(gradients[0]).tolist(),
-        [[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]])
-    self.assertEqual(np_array_ops.array(gradients[1]).tolist(), [1., 1.])
-
-
-if __name__ == '__main__':
-  ops.enable_eager_execution()
-  test.main()
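The gradients asserted by the deleted test follow directly from the functional update: overwriting index 1 cuts the original element out of the loss, while the written value contributes gradient 1. A plain-TF re-derivation of the first case (reusing the illustrative set_index helper from the sketch above):

    import tensorflow as tf

    def set_index(t, i, value):
      return tf.concat([t[:i], tf.expand_dims(value, 0), t[i + 1:]], 0)

    a = tf.constant([1., 2., 3.])
    b = tf.constant(5.)
    c = tf.constant(10.)
    with tf.GradientTape() as g:
      g.watch([a, b, c])
      # loss = a[0] + (b + c) + a[2], since index 1 is replaced.
      loss = tf.reduce_sum(set_index(a, 1, b + c))

    da, db, dc = g.gradient(loss, [a, b, c])
    print(da.numpy(), db.numpy(), dc.numpy())  # [1. 0. 1.] 1.0 1.0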