Use OnesLike instead of Ones(shape, dtype) for resource tensors returned from functions (ForwardAccumulator+GradientTape)
PiperOrigin-RevId: 359589786 Change-Id: Idaca1b869a4fada34a975dbcbe74f212ea037c0c
This commit is contained in:
parent
fce0a27c38
commit
35b5b4228d
@ -352,20 +352,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsNone(acc.jvp(v))
|
||||
self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero"))
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testFunctionReturnsResource(self):
|
||||
v = variables.Variable([[1.]])
|
||||
x = constant_op.constant(1.)
|
||||
xt = constant_op.constant(2.)
|
||||
|
||||
@def_function.function
|
||||
def f(a):
|
||||
return a, v.handle
|
||||
|
||||
with forwardprop.ForwardAccumulator(x, xt) as acc:
|
||||
y, _ = f(x)
|
||||
self.assertAllClose(2., acc.jvp(y))
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
def testMultipleWatchesAdd(self):
|
||||
x = constant_op.constant(-2.)
|
||||
|
@ -2086,14 +2086,6 @@ bool ListContainsNone(PyObject* list) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// As an optimization, the tape generally keeps only the shape and dtype of
|
||||
// tensors, and uses this information to generate ones/zeros tensors. However,
|
||||
// some tensors require OnesLike/ZerosLike because their gradients do not match
|
||||
// their inference shape/dtype.
|
||||
bool DTypeNeedsHandleData(tensorflow::DataType dtype) {
|
||||
return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE;
|
||||
}
|
||||
|
||||
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
||||
if (EagerTensor_CheckExact(tensor)) {
|
||||
tensorflow::ImmediateExecutionTensorHandle* handle =
|
||||
@ -2101,7 +2093,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
||||
tensorflow::int64 id = PyEagerTensor_ID(tensor);
|
||||
tensorflow::DataType dtype =
|
||||
static_cast<tensorflow::DataType>(handle->DataType());
|
||||
if (DTypeNeedsHandleData(dtype)) {
|
||||
if (dtype == tensorflow::DT_VARIANT) {
|
||||
return PyTapeTensor(id, dtype, tensor);
|
||||
}
|
||||
|
||||
@ -2147,7 +2139,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
||||
tensorflow::TensorShape({}));
|
||||
}
|
||||
|
||||
if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) {
|
||||
if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
|
||||
return PyTapeTensor(id, dtype, tensor);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user