Use OnesLike instead of Ones(shape, dtype) for resource tensors returned from functions (ForwardAccumulator+GradientTape)

PiperOrigin-RevId: 359589786
Change-Id: Idaca1b869a4fada34a975dbcbe74f212ea037c0c
This commit is contained in:
A. Unique TensorFlower 2021-02-25 12:45:32 -08:00 committed by TensorFlower Gardener
parent fce0a27c38
commit 35b5b4228d
2 changed files with 2 additions and 24 deletions

View File

@ -352,20 +352,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
self.assertIsNone(acc.jvp(v))
self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionReturnsResource(self):
v = variables.Variable([[1.]])
x = constant_op.constant(1.)
xt = constant_op.constant(2.)
@def_function.function
def f(a):
return a, v.handle
with forwardprop.ForwardAccumulator(x, xt) as acc:
y, _ = f(x)
self.assertAllClose(2., acc.jvp(y))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMultipleWatchesAdd(self):
x = constant_op.constant(-2.)

View File

@ -2086,14 +2086,6 @@ bool ListContainsNone(PyObject* list) {
return false;
}
// As an optimization, the tape generally keeps only the shape and dtype of
// tensors, and uses this information to generate ones/zeros tensors. However,
// some tensors require OnesLike/ZerosLike because their gradients do not match
// their inference shape/dtype.
bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
}
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
tensorflow::ImmediateExecutionTensorHandle* handle =
@ -2101,7 +2093,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(handle->DataType());
if (DTypeNeedsHandleData(dtype)) {
if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}
@ -2147,7 +2139,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::TensorShape({}));
}
if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}