Make AbstractTensorHandle RefCounted so that we can use it for refcounting tensors under a GradientTape.

PiperOrigin-RevId: 322677493
Change-Id: I054d6127d6ec159be197f524ee6190c2537b1662
This commit is contained in:
Saurabh Saxena 2020-07-22 16:16:08 -07:00 committed by TensorFlower Gardener
parent f9dd45444a
commit ac9c8e2db5
9 changed files with 25 additions and 38 deletions

View File

@ -262,6 +262,7 @@ cc_library(
],
deps = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
],
)

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
class AbstractTensorHandle : public core::RefCounted {
protected:
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
@ -34,14 +35,6 @@ class AbstractTensorHandle {
AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private:
const AbstractTensorHandleKind kind_;
};
@ -50,7 +43,7 @@ namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
p->Unref();
}
}
};

View File

@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }

View File

@ -49,7 +49,6 @@ class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
void Release() override { delete this; }
tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));

View File

@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) {
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
handle_->Ref();
ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
TapeTensor::~TapeTensor() { handle_->Unref(); }
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
@ -192,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release();
gradient->Unref();
}
// Helper functions which delegate to `AbstractOperation`, update

View File

@ -93,7 +93,7 @@ Status AddGradModel(AbstractContext* ctx,
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto add_output : add_outputs) {
add_output->Release();
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
@ -144,14 +144,14 @@ Status RunModel(Model model, AbstractContext* ctx,
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
for (auto func_input : func_inputs) {
func_input->Release();
func_input->Unref();
}
AbstractFunction* func = nullptr;
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Release();
output_list.outputs[1]->Release();
output_list.outputs[0]->Unref();
output_list.outputs[1]->Unref();
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
@ -252,7 +252,7 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Release();
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
@ -260,7 +260,7 @@ TEST_P(CppGradients, TestAddGrad) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Release();
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
}
@ -270,7 +270,7 @@ TEST_P(CppGradients, TestAddGrad) {
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*tfrt*/ ::testing::Values(true, false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(

View File

@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
// Return a copy of the handle.
virtual ImmediateExecutionTensorHandle* Copy() = 0;
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage its own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;

View File

@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle {
return type;
}
void Release() override { delete this; }
Value getValue() { return value_; }
// For LLVM style RTTI.

View File

@ -53,8 +53,7 @@ class EagerContext;
// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle,
public core::RefCounted {
class TensorHandle : public ImmediateExecutionTensorHandle {
// TensorHandle for dtype != DT_RESOURCE
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
Device* resource_device, EagerContext* ctx);