From 35b5b4228de7d4a701307153c46fb1bb671bbf45 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 Feb 2021 12:45:32 -0800 Subject: [PATCH] Use OnesLike instead of Ones(shape, dtype) for resource tensors returned from functions (ForwardAccumulator+GradientTape) PiperOrigin-RevId: 359589786 Change-Id: Idaca1b869a4fada34a975dbcbe74f212ea037c0c --- tensorflow/python/eager/forwardprop_test.py | 14 -------------- tensorflow/python/eager/pywrap_tfe_src.cc | 12 ++---------- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index addcd96d036..47c6f6b3225 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -352,20 +352,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): self.assertIsNone(acc.jvp(v)) self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero")) - @test_util.assert_no_new_pyobjects_executing_eagerly - def testFunctionReturnsResource(self): - v = variables.Variable([[1.]]) - x = constant_op.constant(1.) - xt = constant_op.constant(2.) - - @def_function.function - def f(a): - return a, v.handle - - with forwardprop.ForwardAccumulator(x, xt) as acc: - y, _ = f(x) - self.assertAllClose(2., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly def testMultipleWatchesAdd(self): x = constant_op.constant(-2.) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 1562969dabe..ccf3bde7c14 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2086,14 +2086,6 @@ bool ListContainsNone(PyObject* list) { return false; } -// As an optimization, the tape generally keeps only the shape and dtype of -// tensors, and uses this information to generate ones/zeros tensors. However, -// some tensors require OnesLike/ZerosLike because their gradients do not match -// their inference shape/dtype. -bool DTypeNeedsHandleData(tensorflow::DataType dtype) { - return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE; -} - static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { tensorflow::ImmediateExecutionTensorHandle* handle = @@ -2101,7 +2093,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::DataType dtype = static_cast(handle->DataType()); - if (DTypeNeedsHandleData(dtype)) { + if (dtype == tensorflow::DT_VARIANT) { return PyTapeTensor(id, dtype, tensor); } @@ -2147,7 +2139,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::TensorShape({})); } - if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) { + if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) { return PyTapeTensor(id, dtype, tensor); }