Make AbstractTensorHandle RefCounted so that we can use it for refcounting tensors under a GradientTape.
PiperOrigin-RevId: 322677493 Change-Id: I054d6127d6ec159be197f524ee6190c2537b1662
This commit is contained in:
parent
f9dd45444a
commit
ac9c8e2db5
@ -262,6 +262,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:refcount",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,11 +18,12 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
// execution mode.
|
||||
class AbstractTensorHandle {
|
||||
class AbstractTensorHandle : public core::RefCounted {
|
||||
protected:
|
||||
enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt };
|
||||
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
||||
@ -34,14 +35,6 @@ class AbstractTensorHandle {
|
||||
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
private:
|
||||
const AbstractTensorHandleKind kind_;
|
||||
};
|
||||
@ -50,7 +43,7 @@ namespace internal {
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
p->Unref();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||
|
||||
void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); }
|
||||
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); }
|
||||
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); }
|
||||
|
||||
TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); }
|
||||
void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); }
|
||||
|
@ -49,7 +49,6 @@ class GraphTensor : public TracingTensorHandle {
|
||||
public:
|
||||
explicit GraphTensor(TF_Output output)
|
||||
: TracingTensorHandle(kGraph), output_(output) {}
|
||||
void Release() override { delete this; }
|
||||
|
||||
tensorflow::DataType DataType() const override {
|
||||
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
||||
|
@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) {
|
||||
|
||||
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
|
||||
: handle_(handle), ctx_(ctx) {
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Ref();
|
||||
handle_->Ref();
|
||||
}
|
||||
TapeTensor::TapeTensor(const TapeTensor& other) {
|
||||
handle_ = other.handle_;
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Ref();
|
||||
handle_->Ref();
|
||||
ctx_ = other.ctx_;
|
||||
}
|
||||
TapeTensor::~TapeTensor() {
|
||||
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
|
||||
// on the client to keep this tensor live for the duration of the gradient
|
||||
// computation.
|
||||
// handle_->Unref();
|
||||
}
|
||||
TapeTensor::~TapeTensor() { handle_->Unref(); }
|
||||
|
||||
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
|
||||
|
||||
@ -192,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
|
||||
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
|
||||
|
||||
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
|
||||
gradient->Release();
|
||||
gradient->Unref();
|
||||
}
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
|
@ -93,7 +93,7 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Release();
|
||||
add_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
@ -144,14 +144,14 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(output_list.outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Release();
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
output_list.outputs[0]->Release();
|
||||
output_list.outputs[1]->Release();
|
||||
output_list.outputs[0]->Unref();
|
||||
output_list.outputs[1]->Unref();
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
@ -252,7 +252,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Release();
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
@ -260,7 +260,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[1]->Release();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
@ -270,7 +270,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, CppGradients,
|
||||
::testing::Combine(::testing::Values("graphdef"),
|
||||
/*tfrt*/ ::testing::Values(false),
|
||||
/*tfrt*/ ::testing::Values(true, false),
|
||||
/*executing_eagerly*/ ::testing::Values(true, false)));
|
||||
#else
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
|
@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
// Return a copy of the handle.
|
||||
virtual ImmediateExecutionTensorHandle* Copy() = 0;
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractTensorHandle* ptr) {
|
||||
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
|
||||
|
@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle {
|
||||
return type;
|
||||
}
|
||||
|
||||
void Release() override { delete this; }
|
||||
|
||||
Value getValue() { return value_; }
|
||||
|
||||
// For LLVM style RTTI.
|
||||
|
@ -53,8 +53,7 @@ class EagerContext;
|
||||
// Associates a Tensor and a Device, used in the eager runtime. Internal version
|
||||
// of the TFE_TensorHandle struct and the python EagerTensor class
|
||||
// (unrelated to python TensorHandle).
|
||||
class TensorHandle : public ImmediateExecutionTensorHandle,
|
||||
public core::RefCounted {
|
||||
class TensorHandle : public ImmediateExecutionTensorHandle {
|
||||
// TensorHandle for dtype != DT_RESOURCE
|
||||
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
Device* resource_device, EagerContext* ctx);
|
||||
|
Loading…
Reference in New Issue
Block a user