Moving RAII helpers for TensorHandle, Tensor, and Operation to their respective classes.

PiperOrigin-RevId: 317578771
Change-Id: Iaf674696ea7d7dfdf94924f4c60d555a613c5f57
This commit is contained in:
Brian Zhao 2020-06-21 19:06:43 -07:00 committed by TensorFlower Gardener
parent f45d6083b7
commit ebf57bdfc7
16 changed files with 152 additions and 262 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <memory>
#include <vector>
#include "tensorflow/c/eager/abstract_function.h"
@ -64,6 +65,19 @@ class AbstractContext {
const AbstractContextKind kind_;
};
namespace internal {
struct AbstractContextDeleter {
void operator()(AbstractContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_

View File

@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
@ -110,6 +112,19 @@ class AbstractOperation {
const AbstractOperationKind kind_;
};
namespace internal {
struct AbstractOperationDeleter {
void operator()(AbstractOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOpPtr =
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_

View File

@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#include <memory>
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
@ -40,6 +42,20 @@ class AbstractTensorHandle {
const AbstractTensorHandleKind kind_;
};
namespace internal {
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_

View File

@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
@ -107,6 +108,20 @@ class ImmediateExecutionContext : public AbstractContext {
~ImmediateExecutionContext() override {}
};
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_

View File

@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@ -48,6 +50,20 @@ class ImmediateExecutionOperation : public AbstractOperation {
~ImmediateExecutionOperation() override {}
};
namespace internal {
struct ImmediateExecutionOperationDeleter {
void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateOpPtr =
std::unique_ptr<ImmediateExecutionOperation,
internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

View File

@ -59,6 +59,20 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
~ImmediateExecutionTensorHandle() override {}
};
namespace internal {
struct ImmediateExecutionTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::ImmediateExecutionTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_

View File

@ -14,44 +14,6 @@ package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "owned_eager_op",
hdrs = [
"owned_eager_op.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_operation",
],
)
cc_library(
name = "owned_tensor_handle",
hdrs = [
"owned_tensor_handle.h",
],
deps = [
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "owned_eager_context",
hdrs = ["owned_eager_context.h"],
deps = [
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/core/common_runtime/eager:context",
],
)
cc_library(
name = "owned_tensor",
hdrs = ["owned_tensor.h"],
deps = [
"//tensorflow/c:tensor_interface",
],
)
cc_library(
name = "variable_ops",
srcs = [
@ -61,10 +23,10 @@ cc_library(
"variable_ops.h",
],
deps = [
":owned_eager_op",
":owned_tensor_handle",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -79,11 +41,11 @@ tf_cc_test(
"variable_ops_test.cc",
],
deps = [
":owned_eager_context",
":owned_tensor",
":owned_tensor_handle",
":variable_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tensorflow {
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct EagerContextDeleter {
void operator()(EagerContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_

View File

@ -1,42 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_operation.h"
namespace tensorflow {
namespace internal {
struct ImmediateExecutionOperationDeleter {
void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOpPtr =
std::unique_ptr<ImmediateExecutionOperation,
internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_

View File

@ -1,42 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
#include <memory>
#include "tensorflow/c/tensor_interface.h"
namespace tensorflow {
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_

View File

@ -1,54 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#include <memory>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow {
namespace internal {
struct TensorHandleDeleter {
void operator()(TensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct AbstractTensorHandleDeleter {
void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using TensorHandlePtr =
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
using AbstractTensorHandlePtr =
std::unique_ptr<ImmediateExecutionTensorHandle,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_

View File

@ -16,10 +16,11 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@ -35,8 +36,8 @@ static const char kNoSharingResourceID[] =
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle) {
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
ImmediateTensorHandlePtr* handle) {
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
@ -55,17 +56,19 @@ Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
AbstractTensorHandlePtr owned_var_handle(var_handle);
if (owned_var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_var_handle.release()));
return Status();
}
Status AssignVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, ImmediateExecutionTensorHandle* value) {
AbstractOpPtr assign_op(ctx->CreateOperation());
ImmediateOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
@ -78,8 +81,8 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output) {
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
DataType dtype, ImmediateTensorHandlePtr* output) {
ImmediateOpPtr read_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
@ -88,16 +91,18 @@ Status ReadVariable(ImmediateExecutionContext* ctx,
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
AbstractTensorHandlePtr owned_value(value);
if (owned_value->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
output->reset(
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
return Status();
}
Status DestroyResource(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
ImmediateOpPtr destroy_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -32,7 +31,7 @@ namespace internal {
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle);
ImmediateTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
@ -48,7 +47,7 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(ImmediateExecutionContext* ctx,
ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output);
DataType dtype, ImmediateTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:

View File

@ -17,9 +17,8 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
@ -30,10 +29,10 @@ limitations under the License.
namespace tensorflow {
namespace {
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
@ -62,7 +61,7 @@ class VariableOpsTest : public ::testing::Test {
// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// The created TensorHandle should be a DT_Resource
@ -72,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
ImmediateTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
@ -83,18 +82,18 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr variable;
ImmediateTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
AbstractTensorHandlePtr read_value_handle;
ImmediateTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;

View File

@ -54,6 +54,20 @@ class AbstractTensorInterface {
virtual ~AbstractTensorInterface() {}
};
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_

View File

@ -722,6 +722,19 @@ inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
return down_cast<EagerContext*>(context);
}
namespace internal {
struct EagerContextDeleter {
void operator()(EagerContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_