diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index dca35b78e0f..61701bc8b21 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -262,6 +262,7 @@ cc_library( ], deps = [ "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:refcount", ], ) diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index de041690420..37e6d1bf29c 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -18,11 +18,12 @@ limitations under the License. #include #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/refcount.h" namespace tensorflow { // Abstract interface to a Tensor handle in either tracing or immediate // execution mode. -class AbstractTensorHandle { +class AbstractTensorHandle : public core::RefCounted { protected: enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} @@ -34,14 +35,6 @@ class AbstractTensorHandle { AbstractTensorHandleKind getKind() const { return kind_; } - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus this must be allocated on the heap and - // clients MUST call Release() in order to destroy an instance of this class. - virtual void Release() = 0; - private: const AbstractTensorHandleKind kind_; }; @@ -50,7 +43,7 @@ namespace internal { struct AbstractTensorHandleDeleter { void operator()(AbstractTensorHandle* p) const { if (p != nullptr) { - p->Release(); + p->Unref(); } } }; diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 605a60c186c..8408f7ef60f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); } -void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); } +void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); } TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 6c903560e52..7bda3aed76d 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -49,7 +49,6 @@ class GraphTensor : public TracingTensorHandle { public: explicit GraphTensor(TF_Output output) : TracingTensorHandle(kGraph), output_(output) {} - void Release() override { delete this; } tensorflow::DataType DataType() const override { return static_cast(TF_OperationOutputType(output_)); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index f5085fdb926..cf62dcea926 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) { TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx) : handle_(handle), ctx_(ctx) { - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Ref(); + handle_->Ref(); } TapeTensor::TapeTensor(const TapeTensor& other) { handle_ = other.handle_; - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Ref(); + handle_->Ref(); ctx_ = other.ctx_; } -TapeTensor::~TapeTensor() { - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Unref(); -} +TapeTensor::~TapeTensor() { handle_->Unref(); } tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); } @@ -192,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const { void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {} void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const { - gradient->Release(); + gradient->Unref(); } // Helper functions which delegate to `AbstractOperation`, update diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 0a3d267e937..e02f189c3d2 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -93,7 +93,7 @@ Status AddGradModel(AbstractContext* ctx, source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads)); for (auto add_output : add_outputs) { - add_output->Release(); + add_output->Unref(); } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; @@ -144,14 +144,14 @@ Status RunModel(Model model, AbstractContext* ctx, TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), absl::MakeSpan(output_list.outputs), registry)); for (auto func_input : func_inputs) { - func_input->Release(); + func_input->Unref(); } AbstractFunction* func = nullptr; TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) ->Finalize(&output_list, &func)); scoped_func.reset(func); - output_list.outputs[0]->Release(); - output_list.outputs[1]->Release(); + output_list.outputs[0]->Unref(); + output_list.outputs[1]->Unref(); TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); } @@ -252,7 +252,7 @@ TEST_P(CppGradients, TestAddGrad) { ASSERT_EQ(errors::OK, s.code()) << s.error_message(); auto result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0); - outputs[0]->Release(); + outputs[0]->Unref(); TF_DeleteTensor(result_tensor); result_tensor = nullptr; @@ -260,7 +260,7 @@ TEST_P(CppGradients, TestAddGrad) { ASSERT_EQ(errors::OK, s.code()) << s.error_message(); result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0); - outputs[1]->Release(); + outputs[1]->Unref(); TF_DeleteTensor(result_tensor); } @@ -270,7 +270,7 @@ TEST_P(CppGradients, TestAddGrad) { INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, ::testing::Combine(::testing::Values("graphdef"), - /*tfrt*/ ::testing::Values(false), + /*tfrt*/ ::testing::Values(true, false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index f7c77aa06db..6d32d482747 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { // Return a copy of the handle. virtual ImmediateExecutionTensorHandle* Copy() = 0; + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus this must be allocated on the heap and + // clients MUST call Release() in order to destroy an instance of this class. + virtual void Release() = 0; + // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index ffd9c149d2d..51890c1e9ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle { return type; } - void Release() override { delete this; } - Value getValue() { return value_; } // For LLVM style RTTI. diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 007ba33f231..99f88fe886a 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -53,8 +53,7 @@ class EagerContext; // Associates a Tensor and a Device, used in the eager runtime. Internal version // of the TFE_TensorHandle struct and the python EagerTensor class // (unrelated to python TensorHandle). -class TensorHandle : public ImmediateExecutionTensorHandle, - public core::RefCounted { +class TensorHandle : public ImmediateExecutionTensorHandle { // TensorHandle for dtype != DT_RESOURCE TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, Device* resource_device, EagerContext* ctx);