diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index addcd96d036..47c6f6b3225 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -352,20 +352,6 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase): self.assertIsNone(acc.jvp(v)) self.assertAllClose([[0.]], acc.jvp(v, unconnected_gradients="zero")) - @test_util.assert_no_new_pyobjects_executing_eagerly - def testFunctionReturnsResource(self): - v = variables.Variable([[1.]]) - x = constant_op.constant(1.) - xt = constant_op.constant(2.) - - @def_function.function - def f(a): - return a, v.handle - - with forwardprop.ForwardAccumulator(x, xt) as acc: - y, _ = f(x) - self.assertAllClose(2., acc.jvp(y)) - @test_util.assert_no_new_pyobjects_executing_eagerly def testMultipleWatchesAdd(self): x = constant_op.constant(-2.) diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 1562969dabe..ccf3bde7c14 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2086,14 +2086,6 @@ bool ListContainsNone(PyObject* list) { return false; } -// As an optimization, the tape generally keeps only the shape and dtype of -// tensors, and uses this information to generate ones/zeros tensors. However, -// some tensors require OnesLike/ZerosLike because their gradients do not match -// their inference shape/dtype. -bool DTypeNeedsHandleData(tensorflow::DataType dtype) { - return dtype == tensorflow::DT_VARIANT || dtype == tensorflow::DT_RESOURCE; -} - static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { tensorflow::ImmediateExecutionTensorHandle* handle = @@ -2101,7 +2093,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::DataType dtype = static_cast(handle->DataType()); - if (DTypeNeedsHandleData(dtype)) { + if (dtype == tensorflow::DT_VARIANT) { return PyTapeTensor(id, dtype, tensor); } @@ -2147,7 +2139,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::TensorShape({})); } - if (ListContainsNone(shape_tuple.get()) || DTypeNeedsHandleData(dtype)) { + if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) { return PyTapeTensor(id, dtype, tensor); }