Moving RAII helpers for TensorHandle, Tensor, and Operation to their respective classes.
PiperOrigin-RevId: 317578771 Change-Id: Iaf674696ea7d7dfdf94924f4c60d555a613c5f57
This commit is contained in:
parent
f45d6083b7
commit
ebf57bdfc7
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
|
@ -64,6 +65,19 @@ class AbstractContext {
|
|||
const AbstractContextKind kind_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractContextDeleter {
|
||||
void operator()(AbstractContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<AbstractContext, internal::AbstractContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
|
@ -110,6 +112,19 @@ class AbstractOperation {
|
|||
const AbstractOperationKind kind_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractOperationDeleter {
|
||||
void operator()(AbstractOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
|
@ -40,6 +42,20 @@ class AbstractTensorHandle {
|
|||
const AbstractTensorHandleKind kind_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<AbstractTensorHandle,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
|
@ -107,6 +108,20 @@ class ImmediateExecutionContext : public AbstractContext {
|
|||
~ImmediateExecutionContext() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionContextDeleter {
|
||||
void operator()(ImmediateExecutionContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateContextPtr =
|
||||
std::unique_ptr<ImmediateExecutionContext,
|
||||
internal::ImmediateExecutionContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
@ -48,6 +50,20 @@ class ImmediateExecutionOperation : public AbstractOperation {
|
|||
~ImmediateExecutionOperation() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionOperationDeleter {
|
||||
void operator()(ImmediateExecutionOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateOpPtr =
|
||||
std::unique_ptr<ImmediateExecutionOperation,
|
||||
internal::ImmediateExecutionOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
|
|
|
@ -59,6 +59,20 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
|||
~ImmediateExecutionTensorHandle() override {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct ImmediateExecutionTensorHandleDeleter {
|
||||
void operator()(ImmediateExecutionTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using ImmediateTensorHandlePtr =
|
||||
std::unique_ptr<ImmediateExecutionTensorHandle,
|
||||
internal::ImmediateExecutionTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
|
|
|
@ -14,44 +14,6 @@ package(
|
|||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_op",
|
||||
hdrs = [
|
||||
"owned_eager_op.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor_handle",
|
||||
hdrs = [
|
||||
"owned_tensor_handle.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_context",
|
||||
hdrs = ["owned_eager_context.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor",
|
||||
hdrs = ["owned_tensor.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_ops",
|
||||
srcs = [
|
||||
|
@ -61,10 +23,10 @@ cc_library(
|
|||
"variable_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_op",
|
||||
":owned_tensor_handle",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -79,11 +41,11 @@ tf_cc_test(
|
|||
"variable_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_context",
|
||||
":owned_tensor",
|
||||
":owned_tensor_handle",
|
||||
":variable_ops",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
|
|
@ -1,54 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct ImmediateExecutionContextDeleter {
|
||||
void operator()(ImmediateExecutionContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct EagerContextDeleter {
|
||||
void operator()(EagerContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<ImmediateExecutionContext,
|
||||
internal::ImmediateExecutionContextDeleter>;
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
|
@ -1,42 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct ImmediateExecutionOperationDeleter {
|
||||
void operator()(ImmediateExecutionOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<ImmediateExecutionOperation,
|
||||
internal::ImmediateExecutionOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
|
@ -1,42 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractTensorInterfaceDeleter {
|
||||
void operator()(AbstractTensorInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorPtr =
|
||||
std::unique_ptr<AbstractTensorInterface,
|
||||
internal::AbstractTensorInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
|
@ -1,54 +0,0 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct TensorHandleDeleter {
|
||||
void operator()(TensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(ImmediateExecutionTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using TensorHandlePtr =
|
||||
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<ImmediateExecutionTensorHandle,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
|
@ -16,10 +16,11 @@ limitations under the License.
|
|||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
|
@ -35,8 +36,8 @@ static const char kNoSharingResourceID[] =
|
|||
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle) {
|
||||
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
ImmediateTensorHandlePtr* handle) {
|
||||
ImmediateOpPtr varhandle_op(ctx->CreateOperation());
|
||||
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
|
||||
|
@ -55,17 +56,19 @@ Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
|||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Execute(
|
||||
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
|
||||
if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
AbstractTensorHandlePtr owned_var_handle(var_handle);
|
||||
if (owned_var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
|
||||
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
|
||||
owned_var_handle.release()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value) {
|
||||
AbstractOpPtr assign_op(ctx->CreateOperation());
|
||||
ImmediateOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
|
||||
|
@ -78,8 +81,8 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
|
|||
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output) {
|
||||
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
DataType dtype, ImmediateTensorHandlePtr* output) {
|
||||
ImmediateOpPtr read_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
@ -88,16 +91,18 @@ Status ReadVariable(ImmediateExecutionContext* ctx,
|
|||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
AbstractTensorHandlePtr owned_value(value);
|
||||
if (owned_value->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
|
||||
output->reset(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle*>(owned_value.release()));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
ImmediateOpPtr destroy_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
|
||||
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
|
||||
|
|
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
@ -32,7 +31,7 @@ namespace internal {
|
|||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle);
|
||||
ImmediateTensorHandlePtr* handle);
|
||||
|
||||
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
|
||||
// with `variable_handle` with `value`. `dtype` must be the datatype of the
|
||||
|
@ -48,7 +47,7 @@ Status AssignVariable(ImmediateExecutionContext* ctx,
|
|||
// the dtype of the variable associated with `variable_handle`.
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output);
|
||||
DataType dtype, ImmediateTensorHandlePtr* output);
|
||||
|
||||
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
|
||||
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
|
||||
|
|
|
@ -17,9 +17,8 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
@ -30,10 +29,10 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
|
||||
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
|
@ -62,7 +61,7 @@ class VariableOpsTest : public ::testing::Test {
|
|||
// Sanity check for variable creation
|
||||
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
// The created TensorHandle should be a DT_Resource
|
||||
|
@ -72,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
|||
// Sanity check for variable destruction
|
||||
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
ImmediateTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
|
||||
|
@ -83,18 +82,18 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
|||
// Sanity check for handle assignment and reading
|
||||
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr variable;
|
||||
ImmediateTensorHandlePtr variable;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &variable));
|
||||
|
||||
// Create a Scalar float TensorHandle with value 42, and assign it to
|
||||
// the variable.
|
||||
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
|
||||
my_value.get()));
|
||||
|
||||
// Read back the value from the variable, and check that it is 42.
|
||||
AbstractTensorHandlePtr read_value_handle;
|
||||
ImmediateTensorHandlePtr read_value_handle;
|
||||
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
|
||||
&read_value_handle));
|
||||
Status status;
|
||||
|
|
|
@ -54,6 +54,20 @@ class AbstractTensorInterface {
|
|||
virtual ~AbstractTensorInterface() {}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
struct AbstractTensorInterfaceDeleter {
|
||||
void operator()(AbstractTensorInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorPtr =
|
||||
std::unique_ptr<AbstractTensorInterface,
|
||||
internal::AbstractTensorInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_
|
||||
|
|
|
@ -722,6 +722,19 @@ inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
|
|||
return down_cast<EagerContext*>(context);
|
||||
}
|
||||
|
||||
namespace internal {
|
||||
struct EagerContextDeleter {
|
||||
void operator()(EagerContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_
|
||||
|
|
Loading…
Reference in New Issue