Split Abstract interfaces into Abstract and ImmediateExecution interfaces.
The Abstract interfaces are shared with tracing mode. Introduce an AbstractFunction which handles the conversion between MLIR function and FunctionDef and the runtime can query whichever representation is suitable. Right now this only supports GetFunctionDef but an API for fetching the MLIR function directly will be added in future changes. PiperOrigin-RevId: 316942774 Change-Id: I1abebbe853b98dd0048bab9fc092252f4caf3d1b
This commit is contained in:
parent
f24487e619
commit
8d5171bad7
|
@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
|
|||
|
||||
const int num_inputs = input_shapes->num_items;
|
||||
NodeDef node_def;
|
||||
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op);
|
||||
tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
|
||||
node_def.set_name(op->Name());
|
||||
node_def.set_op(op->Name());
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
|
|
|
@ -38,9 +38,10 @@ tf_cuda_library(
|
|||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
],
|
||||
"//conditions:default": [
|
||||
":context_interface",
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
":abstract_tensor_handle",
|
||||
":tfe_context_internal",
|
||||
":tfe_cancellation_manager_internal",
|
||||
":tfe_executor_internal",
|
||||
|
@ -101,13 +102,17 @@ tf_cuda_library(
|
|||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"abstract_context.h",
|
||||
"abstract_function.h",
|
||||
"abstract_operation.h",
|
||||
"abstract_tensor_handle.h",
|
||||
"c_api_experimental.h",
|
||||
"c_api_internal.h",
|
||||
"c_api_unified_experimental.h",
|
||||
"context_interface.h",
|
||||
"dlpack.h",
|
||||
"operation_interface.h",
|
||||
"tensor_handle_interface.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
|
@ -163,12 +168,22 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensor_handle_interface",
|
||||
hdrs = ["tensor_handle_interface.h"],
|
||||
name = "abstract_tensor_handle",
|
||||
hdrs = ["abstract_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_tensor_handle",
|
||||
hdrs = ["immediate_execution_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -177,13 +192,13 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "operation_interface",
|
||||
hdrs = ["operation_interface.h"],
|
||||
name = "abstract_operation",
|
||||
hdrs = ["abstract_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -193,14 +208,58 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "context_interface",
|
||||
hdrs = ["context_interface.h"],
|
||||
name = "immediate_execution_operation",
|
||||
hdrs = ["immediate_execution_operation.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":tensor_handle_interface",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_context",
|
||||
hdrs = ["abstract_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_function",
|
||||
":abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_function",
|
||||
hdrs = ["abstract_function.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "immediate_execution_context",
|
||||
hdrs = ["immediate_execution_context.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":immediate_execution_operation",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -217,7 +276,7 @@ cc_library(
|
|||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":context_interface",
|
||||
":immediate_execution_context",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
@ -277,7 +336,7 @@ cc_library(
|
|||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":operation_interface",
|
||||
":immediate_execution_operation",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
@ -300,7 +359,7 @@ cc_library(
|
|||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tensor_handle_interface",
|
||||
":immediate_execution_tensor_handle",
|
||||
"//tensorflow/c:conversion_macros",
|
||||
],
|
||||
)
|
||||
|
@ -480,6 +539,9 @@ tf_cuda_library(
|
|||
":tfe_context_internal",
|
||||
":tfe_op_internal",
|
||||
":tfe_tensorhandle_internal",
|
||||
":abstract_operation",
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a context.
|
||||
//
|
||||
// This serves as a factory for creating `AbstractOperation`s and for
|
||||
// registering traced functions.
|
||||
// Operations creation within a context can only be executed in that context
|
||||
// (for now at least).
|
||||
// Implementations of the context may contain some state e.g. an execution
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
public:
|
||||
AbstractContextKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
// Creates an operation builder and ties it to this context.
|
||||
// The returned object can be used for setting operation's attributes,
|
||||
// adding inputs and finally executing (immediately or lazily as in tracing)
|
||||
// it in this context.
|
||||
virtual AbstractOperation* CreateOperation() = 0;
|
||||
|
||||
// Registers a function with this context, after this the function is
|
||||
// available to be called/referenced by its name in this context.
|
||||
virtual Status RegisterFunction(AbstractFunction*) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
private:
|
||||
const AbstractContextKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A traced function: this hides the complexity of converting the serialized
|
||||
// representation between various supported formats e.g. FunctionDef and Mlir
|
||||
// function.
|
||||
class AbstractFunction {
|
||||
protected:
|
||||
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
|
||||
|
||||
public:
|
||||
// Returns which subclass is this instance of.
|
||||
AbstractFunctionKind getKind() const { return kind_; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Returns the AbstractFunction as a FunctionDef.
|
||||
virtual Status GetFunctionDef(FunctionDef**) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
|
|
@ -12,24 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class AbstractOperationInterface {
|
||||
// This interface allows building and executing an operation in either
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
public:
|
||||
AbstractOperationKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
|
@ -38,7 +43,6 @@ class AbstractOperationInterface {
|
|||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
virtual void Clear() = 0;
|
||||
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
|
||||
|
||||
virtual const string& Name() const = 0;
|
||||
|
@ -66,12 +70,10 @@ class AbstractOperationInterface {
|
|||
// existing and given constraints will be performed.
|
||||
virtual Status SetDeviceName(const char* name) = 0;
|
||||
|
||||
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0;
|
||||
virtual Status AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
|
||||
virtual Status AddInput(AbstractTensorHandle* input) = 0;
|
||||
virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
|
||||
virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) = 0;
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) = 0;
|
||||
|
@ -82,7 +84,7 @@ class AbstractOperationInterface {
|
|||
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) = 0;
|
||||
virtual Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperationInterface* value) = 0;
|
||||
const AbstractOperation* value) = 0;
|
||||
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) = 0;
|
||||
virtual Status SetAttrTensor(const char* attr_name,
|
||||
|
@ -102,19 +104,12 @@ class AbstractOperationInterface {
|
|||
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) = 0;
|
||||
virtual Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) = 0;
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractOperationInterface() {}
|
||||
private:
|
||||
const AbstractOperationKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
|
|
@ -0,0 +1,45 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||
// execution mode.
|
||||
class AbstractTensorHandle {
|
||||
protected:
|
||||
enum AbstractTensorHandleKind { kTracing, kImmediateExecution };
|
||||
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractTensorHandle() {}
|
||||
|
||||
public:
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
private:
|
||||
const AbstractTensorHandleKind kind_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
|
|
@ -21,6 +21,8 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
// clang-format off
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
@ -31,8 +33,8 @@ limitations under the License.
|
|||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
|
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
|||
|
||||
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
|
||||
TF_Status* status) {
|
||||
tensorflow::AbstractOperationInterface* new_op =
|
||||
tensorflow::ImmediateExecutionOperation* new_op =
|
||||
tensorflow::unwrap(ctx)->CreateOperation();
|
||||
status->status = new_op->Reset(op_or_function_name, nullptr);
|
||||
if (!status->status.ok()) {
|
||||
|
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
|
|||
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->AddInputList(
|
||||
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)});
|
||||
{reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(inputs)),
|
||||
static_cast<size_t>(num_inputs)});
|
||||
}
|
||||
|
||||
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
|
||||
|
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
|
|||
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
|
||||
const TFE_Op** value, int num_values) {
|
||||
auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
|
||||
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)});
|
||||
attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
|
||||
tensorflow::unwrap(value)),
|
||||
static_cast<size_t>(num_values)});
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << "Unable to set attribute: " << attr_name;
|
||||
}
|
||||
|
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
|
|||
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
||||
TF_Status* status) {
|
||||
status->status = tensorflow::unwrap(op)->Execute(
|
||||
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals);
|
||||
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
|
||||
tensorflow::unwrap(retvals)),
|
||||
*num_retvals),
|
||||
num_retvals);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
|
|
|
@ -38,7 +38,7 @@ using tensorflow::string;
|
|||
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
|
||||
const char* raw_device_name, TF_Status* status) {
|
||||
if (op_to_reset) {
|
||||
tensorflow::AbstractOperationInterface* op =
|
||||
tensorflow::ImmediateExecutionOperation* op =
|
||||
tensorflow::unwrap(op_to_reset);
|
||||
op->Clear();
|
||||
status->status = op->Reset(op_or_function_name, raw_device_name);
|
||||
|
|
|
@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
|
@ -34,16 +35,9 @@ namespace tensorflow {
|
|||
//
|
||||
// A context is responsible for creating key objects such as Tensors,
|
||||
// TensorHandles & Operations.
|
||||
class AbstractContextInterface {
|
||||
class ImmediateExecutionContext : public AbstractContext {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus clients MUST call Release() in order to
|
||||
// destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
|
||||
static constexpr AbstractContextKind kKind = kImmediateExecution;
|
||||
// Optimized scalar creation functions
|
||||
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
|
||||
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
|
||||
|
@ -74,15 +68,15 @@ class AbstractContextInterface {
|
|||
void* memory_releaser_arg) = 0;
|
||||
|
||||
// Create a handle to wrap and manage a Tensor
|
||||
virtual AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) = 0;
|
||||
// Copy the handle to another device.
|
||||
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
|
||||
ImmediateExecutionTensorHandle* handle, const char* device_name,
|
||||
Status* status) = 0;
|
||||
|
||||
// Create an operation to perform op execution
|
||||
virtual AbstractOperationInterface* CreateOperation() = 0;
|
||||
ImmediateExecutionOperation* CreateOperation() override = 0;
|
||||
|
||||
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
|
||||
// Runtime. This is necessary to decouple runtime-dependent
|
||||
|
@ -107,14 +101,12 @@ class AbstractContextInterface {
|
|||
// be executed as an op. Return error if the function with the same name
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
ImmediateExecutionContext() : AbstractContext(kKind) {}
|
||||
~ImmediateExecutionContext() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
|
|
@ -0,0 +1,53 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
struct TFE_Op;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Abstract interface to an operation.
|
||||
class ImmediateExecutionOperation : public AbstractOperation {
|
||||
public:
|
||||
static constexpr AbstractOperationKind kKind = kImmediateExecution;
|
||||
virtual void Clear() = 0;
|
||||
|
||||
virtual const tensorflow::OpDef* OpDef() const = 0;
|
||||
|
||||
virtual Status InputLength(const char* input_name, int* length) = 0;
|
||||
virtual Status OutputLength(const char* output_name, int* length) = 0;
|
||||
|
||||
// Experimental
|
||||
virtual Status SetUseXla(bool enable) = 0;
|
||||
|
||||
protected:
|
||||
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
|
||||
~ImmediateExecutionOperation() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
|
|
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
@ -30,15 +31,9 @@ namespace tensorflow {
|
|||
// files. The interface lists the common functionality that must be provided by
|
||||
// any concrete implementation. However, in cases where the true concrete class
|
||||
// is needed a static_cast can be applied.
|
||||
class AbstractTensorHandleInterface {
|
||||
class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
|
||||
public:
|
||||
// Release any underlying resources, including the interface object.
|
||||
//
|
||||
// WARNING: The destructor of this class is marked as protected to disallow
|
||||
// clients from directly destroying this object since it may manage it's own
|
||||
// lifetime through ref counting. Thus this must be allocated on the heap and
|
||||
// clients MUST call Release() in order to destroy an instance of this class.
|
||||
virtual void Release() = 0;
|
||||
static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
|
||||
|
||||
// Returns tensor dtype.
|
||||
virtual tensorflow::DataType DataType() const = 0;
|
||||
|
@ -57,12 +52,13 @@ class AbstractTensorHandleInterface {
|
|||
virtual AbstractTensorInterface* Resolve(Status* status) = 0;
|
||||
|
||||
// Return a copy of the handle.
|
||||
virtual AbstractTensorHandleInterface* Copy() = 0;
|
||||
virtual ImmediateExecutionTensorHandle* Copy() = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractTensorHandleInterface() {}
|
||||
ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
|
||||
~ImmediateExecutionTensorHandle() override {}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
|
||||
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
|
||||
// Wraps a pointer to a context implementation.
|
||||
//
|
||||
|
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
|
||||
// Wraps a pointer to an operation implementation.
|
||||
//
|
||||
|
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ limitations under the License.
|
|||
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
// Wraps a pointer to a tensor handle implementation.
|
||||
//
|
||||
|
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface,
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
|
||||
TFE_TensorHandle);
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*,
|
||||
DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
|
||||
TFE_TensorHandle*);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -23,8 +23,8 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":function_metadata",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,12 +15,12 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>&
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
|
||||
ConcreteFunction::GetCaptures() const {
|
||||
return captures_;
|
||||
}
|
||||
|
|
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
|
||||
|
@ -38,15 +38,15 @@ class ConcreteFunction {
|
|||
virtual ~ConcreteFunction() = 0;
|
||||
|
||||
// This method returns the "Call" Op used to execute the function.
|
||||
virtual AbstractOperationInterface* GetCallOp() = 0;
|
||||
virtual ImmediateExecutionOperation* GetCallOp() = 0;
|
||||
|
||||
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures()
|
||||
const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
|
||||
const;
|
||||
const FunctionMetadata& GetFunctionMetadata() const;
|
||||
|
||||
private:
|
||||
FunctionMetadata metadata_;
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_;
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
|
||||
FunctionDef* function_;
|
||||
};
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ cc_library(
|
|||
"owned_eager_op.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,7 +30,7 @@ cc_library(
|
|||
"owned_tensor_handle.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
@ -39,7 +39,7 @@ cc_library(
|
|||
name = "owned_eager_context",
|
||||
hdrs = ["owned_eager_context.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
],
|
||||
)
|
||||
|
@ -63,8 +63,9 @@ cc_library(
|
|||
deps = [
|
||||
":owned_eager_op",
|
||||
":owned_tensor_handle",
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
|
|
@ -18,14 +18,14 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractContextInterfaceDeleter {
|
||||
void operator()(AbstractContextInterface* p) const {
|
||||
struct ImmediateExecutionContextDeleter {
|
||||
void operator()(ImmediateExecutionContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
|
@ -43,8 +43,8 @@ struct EagerContextDeleter {
|
|||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<AbstractContextInterface,
|
||||
internal::AbstractContextInterfaceDeleter>;
|
||||
std::unique_ptr<ImmediateExecutionContext,
|
||||
internal::ImmediateExecutionContextDeleter>;
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
|
|
@ -18,13 +18,13 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractOperationInterfaceDeleter {
|
||||
void operator()(AbstractOperationInterface* p) const {
|
||||
struct ImmediateExecutionOperationDeleter {
|
||||
void operator()(ImmediateExecutionOperation* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
|
@ -34,8 +34,8 @@ struct AbstractOperationInterfaceDeleter {
|
|||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<AbstractOperationInterface,
|
||||
internal::AbstractOperationInterfaceDeleter>;
|
||||
std::unique_ptr<ImmediateExecutionOperation,
|
||||
internal::ImmediateExecutionOperationDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -33,7 +33,7 @@ struct TensorHandleDeleter {
|
|||
};
|
||||
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandleInterface* p) const {
|
||||
void operator()(ImmediateExecutionTensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ using TensorHandlePtr =
|
|||
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<AbstractTensorHandleInterface,
|
||||
std::unique_ptr<ImmediateExecutionTensorHandle,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -16,7 +16,8 @@ limitations under the License.
|
|||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
@ -32,7 +33,7 @@ namespace internal {
|
|||
static const char kNoSharingResourceID[] =
|
||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle) {
|
||||
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
|
@ -50,17 +51,20 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
|||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
|
||||
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
|
||||
|
||||
AbstractTensorHandleInterface* var_handle = nullptr;
|
||||
AbstractTensorHandle* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Execute(
|
||||
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
|
||||
handle->reset(var_handle);
|
||||
if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value) {
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value) {
|
||||
AbstractOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
|
@ -72,24 +76,27 @@ Status AssignVariable(AbstractContextInterface* ctx,
|
|||
return Status();
|
||||
}
|
||||
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output) {
|
||||
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
||||
AbstractTensorHandleInterface* value = nullptr;
|
||||
AbstractTensorHandle* value = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
output->reset(value);
|
||||
if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
|
||||
return errors::Internal("Unexpected tensor handle kind.");
|
||||
}
|
||||
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle) {
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
|
||||
|
|
|
@ -16,8 +16,8 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
@ -30,7 +30,7 @@ namespace internal {
|
|||
// TensorHandle associated with the variable. This is equivalent to creating an
|
||||
// unitialized TF2 tf.Variable.
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle);
|
||||
|
||||
|
@ -39,22 +39,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
|||
// underlying variable for `variable_handle`. Note that it is illegal to assign
|
||||
// a variable to a Tensor with a different dtype than what the variable was
|
||||
// created with.
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value);
|
||||
Status AssignVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, ImmediateExecutionTensorHandle* value);
|
||||
|
||||
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
|
||||
// value of `variable_handle` and copies the value to `output`. `dtype` must be
|
||||
// the dtype of the variable associated with `variable_handle`.
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
Status ReadVariable(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output);
|
||||
|
||||
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
|
||||
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle);
|
||||
Status DestroyResource(ImmediateExecutionContext* ctx,
|
||||
ImmediateExecutionTensorHandle* handle);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -178,7 +178,7 @@ cc_library(
|
|||
":tensorhandle_list_type",
|
||||
"//tensorflow/c:c_api_macros",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
],
|
||||
)
|
||||
|
@ -190,7 +190,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
"//tensorflow/c:conversion_macros",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/conversion_macros.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
// Internal structures used by the SavedModel C API. These are likely to
|
||||
// change and should not be depended on.
|
||||
|
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
|
|||
namespace tensorflow {
|
||||
|
||||
DEFINE_CONVERSION_FUNCTIONS(
|
||||
std::vector<tensorflow::AbstractTensorHandleInterface*>,
|
||||
std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
|
||||
TF_TensorHandleList)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -32,6 +32,8 @@ tf_cuda_library(
|
|||
":tensor_handle",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -74,9 +76,9 @@ tf_cuda_library(
|
|||
":kernel_and_device",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
] + select({
|
||||
|
@ -137,8 +139,10 @@ tf_cuda_library(
|
|||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_operation",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -211,7 +215,7 @@ tf_cuda_library(
|
|||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
|
@ -496,6 +500,8 @@ cc_library(
|
|||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/core/platform:errors",
|
||||
] + select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:portable_tensorflow_lib_lite",
|
||||
|
|
|
@ -30,8 +30,6 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||
|
|
|
@ -33,7 +33,7 @@ limitations under the License.
|
|||
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
|
@ -135,7 +135,7 @@ class CustomDevice {
|
|||
// TensorHandles may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
||||
public:
|
||||
static constexpr uint64 kInvalidContextId = 0;
|
||||
|
||||
|
@ -178,12 +178,14 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
MemoryReleaser memory_releaser,
|
||||
void* memory_releaser_arg) override;
|
||||
|
||||
AbstractTensorHandleInterface* CreateLocalHandle(
|
||||
ImmediateExecutionTensorHandle* CreateLocalHandle(
|
||||
AbstractTensorInterface* t) override;
|
||||
AbstractTensorHandleInterface* CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
|
||||
ImmediateExecutionTensorHandle* handle, const char* device_name,
|
||||
Status* status) override;
|
||||
AbstractOperationInterface* CreateOperation() override;
|
||||
ImmediateExecutionOperation* CreateOperation() override;
|
||||
|
||||
Status RegisterFunction(AbstractFunction* f) override;
|
||||
|
||||
bool UsesTFRT() override;
|
||||
|
||||
|
@ -716,7 +718,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
std::function<void()> resource_deallocator_ = nullptr;
|
||||
};
|
||||
|
||||
inline EagerContext* ContextFromInterface(AbstractContextInterface* context) {
|
||||
inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
|
||||
return down_cast<EagerContext*>(context);
|
||||
}
|
||||
|
||||
|
|
|
@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/abstract_function.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
#include "tensorflow/core/common_runtime/eager/execute.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -112,8 +114,8 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
|
|||
}
|
||||
}
|
||||
|
||||
AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice(
|
||||
AbstractTensorHandleInterface* handle, const char* device_name,
|
||||
ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
|
||||
ImmediateExecutionTensorHandle* handle, const char* device_name,
|
||||
Status* status) {
|
||||
TensorHandle* input = TensorHandleFromInterface(handle);
|
||||
TensorHandle* result = nullptr;
|
||||
|
@ -158,7 +160,7 @@ AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice(
|
|||
// here to a circular BUILD dep issue. If we move this to context.cc, then we
|
||||
// will have the circular dependency of:
|
||||
// context -> tensor_handle -> remote_tensor_handle_data -> context
|
||||
AbstractTensorHandleInterface* EagerContext::CreateLocalHandle(
|
||||
ImmediateExecutionTensorHandle* EagerContext::CreateLocalHandle(
|
||||
AbstractTensorInterface* t) {
|
||||
Tensor tensor = TensorFromInterface(t);
|
||||
return TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/HostCPU(),
|
||||
|
@ -168,14 +170,23 @@ AbstractTensorHandleInterface* EagerContext::CreateLocalHandle(
|
|||
// TODO(b/152902651): We have to keep this function here since EagerOperation
|
||||
// depends on EagerContext. Thus, the context build target can't depend on
|
||||
// EagerOperation.
|
||||
AbstractOperationInterface* EagerContext::CreateOperation() {
|
||||
ImmediateExecutionOperation* EagerContext::CreateOperation() {
|
||||
return new EagerOperation(this);
|
||||
}
|
||||
|
||||
Status EagerContext::RegisterFunction(AbstractFunction* f) {
|
||||
FunctionDef* fdef;
|
||||
TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
|
||||
if (!fdef) {
|
||||
return errors::InvalidArgument("GetFunctionDef returned nullptr.");
|
||||
}
|
||||
return AddFunctionDef(*fdef);
|
||||
}
|
||||
|
||||
// TODO(b/152902651): Once we move many execute.cc functions into
|
||||
// eager_operation.cc we can avoid a circular dependency between them.
|
||||
Status EagerOperation::Execute(
|
||||
absl::Span<AbstractTensorHandleInterface*> retvals, int* num_retvals) {
|
||||
Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
return EagerExecute(
|
||||
this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
|
||||
num_retvals);
|
||||
|
|
|
@ -15,7 +15,9 @@ limitations under the License.
|
|||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
|
@ -91,8 +93,8 @@ Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerOperation::SetAttrFunction(
|
||||
const char* attr_name, const AbstractOperationInterface* value) {
|
||||
Status EagerOperation::SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) {
|
||||
AttrValue attr_value;
|
||||
NameAttrList* func = attr_value.mutable_func();
|
||||
func->set_name(value->Name());
|
||||
|
@ -194,8 +196,7 @@ Status EagerOperation::SetAttrShapeList(const char* attr_name,
|
|||
}
|
||||
|
||||
Status EagerOperation::SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) {
|
||||
const char* attr_name, absl::Span<const AbstractOperation*> values) {
|
||||
size_t num_values = values.size();
|
||||
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
|
||||
for (int i = 0; i < num_values; i++) {
|
||||
|
@ -253,14 +254,13 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) {
|
||||
Status EagerOperation::AddInput(AbstractTensorHandle* input) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
AddTensorHandle(h);
|
||||
return MaybeInferSingleInputAttrs(h);
|
||||
}
|
||||
|
||||
Status EagerOperation::AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) {
|
||||
Status EagerOperation::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
|
||||
for (auto& input : inputs) {
|
||||
TensorHandle* h = TensorHandleFromInterface(input);
|
||||
AddTensorHandle(h);
|
||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
|
@ -31,7 +32,7 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
class EagerOperation : public AbstractOperationInterface {
|
||||
class EagerOperation : public ImmediateExecutionOperation {
|
||||
public:
|
||||
explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
|
||||
~EagerOperation() override {
|
||||
|
@ -56,7 +57,7 @@ class EagerOperation : public AbstractOperationInterface {
|
|||
}
|
||||
|
||||
// Replaces the previous device name with the given one (see
|
||||
// AbstractOperationInterface::SetDeviceName for more details).
|
||||
// AbstractOperation::SetDeviceName for more details).
|
||||
//
|
||||
// This also resets the internal device pointer, unless the given name refers
|
||||
// to a known custom device, in which case the internal device pointer is
|
||||
|
@ -76,10 +77,9 @@ class EagerOperation : public AbstractOperationInterface {
|
|||
|
||||
Status SetAttrValue(const char* attr_name, const AttrValue& value);
|
||||
|
||||
Status AddInput(AbstractTensorHandleInterface* input) override;
|
||||
Status AddInputList(
|
||||
absl::Span<AbstractTensorHandleInterface*> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
const tensorflow::OpDef* OpDef() const override { return op_def_; };
|
||||
|
||||
|
@ -92,7 +92,7 @@ class EagerOperation : public AbstractOperationInterface {
|
|||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperationInterface* value) override;
|
||||
const AbstractOperation* value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
|
@ -111,7 +111,7 @@ class EagerOperation : public AbstractOperationInterface {
|
|||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperationInterface*> values) override;
|
||||
absl::Span<const AbstractOperation*> values) override;
|
||||
|
||||
Status InputLength(const char* input_name, int* length) override;
|
||||
Status OutputLength(const char* output_name, int* length) override;
|
||||
|
@ -235,7 +235,7 @@ inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
|
|||
}
|
||||
|
||||
inline EagerOperation* OperationFromInterface(
|
||||
AbstractOperationInterface* operation) {
|
||||
ImmediateExecutionOperation* operation) {
|
||||
return down_cast<EagerOperation*>(operation);
|
||||
}
|
||||
|
||||
|
|
|
@ -1071,7 +1071,7 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
|
|||
}
|
||||
}
|
||||
|
||||
tensorflow::AbstractTensorHandleInterface* TensorHandle::Copy() {
|
||||
tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
|
||||
Ref();
|
||||
return this;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ limitations under the License.
|
|||
// clang-format on
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
|
||||
|
@ -53,7 +53,7 @@ class EagerContext;
|
|||
// Associates a Tensor and a Device, used in the eager runtime. Internal version
|
||||
// of the TFE_TensorHandle struct and the python EagerTensor class
|
||||
// (unrelated to python TensorHandle).
|
||||
class TensorHandle : public AbstractTensorHandleInterface,
|
||||
class TensorHandle : public ImmediateExecutionTensorHandle,
|
||||
public core::RefCounted {
|
||||
// TensorHandle for dtype != DT_RESOURCE
|
||||
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
|
@ -121,7 +121,7 @@ class TensorHandle : public AbstractTensorHandleInterface,
|
|||
const char* BackingDeviceName(Status* status) const override;
|
||||
AbstractTensorInterface* Resolve(Status* status) override;
|
||||
|
||||
AbstractTensorHandleInterface* Copy() override;
|
||||
ImmediateExecutionTensorHandle* Copy() override;
|
||||
|
||||
// Return the Tensor from the default device.
|
||||
Status Tensor(const tensorflow::Tensor** t) const;
|
||||
|
@ -372,12 +372,12 @@ const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
|
|||
// Returns the device backing the resource. Else, returns nullptr.
|
||||
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
|
||||
|
||||
class TensorHandleInterface : public AbstractTensorHandleInterface {
|
||||
class TensorHandleInterface : public ImmediateExecutionTensorHandle {
|
||||
public:
|
||||
};
|
||||
|
||||
inline TensorHandle* TensorHandleFromInterface(
|
||||
AbstractTensorHandleInterface* handle) {
|
||||
template <typename T>
|
||||
inline TensorHandle* TensorHandleFromInterface(T* handle) {
|
||||
return down_cast<TensorHandle*>(handle);
|
||||
}
|
||||
|
||||
|
|
|
@ -2008,7 +2008,7 @@ bool ListContainsNone(PyObject* list) {
|
|||
|
||||
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
|
||||
if (EagerTensor_CheckExact(tensor)) {
|
||||
tensorflow::AbstractTensorHandleInterface* handle =
|
||||
tensorflow::ImmediateExecutionTensorHandle* handle =
|
||||
tensorflow::unwrap(EagerTensor_Handle(tensor));
|
||||
tensorflow::int64 id = PyEagerTensor_ID(tensor);
|
||||
tensorflow::DataType dtype =
|
||||
|
@ -3869,7 +3869,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
|
|||
bool include_tensor_ranks_only,
|
||||
EncodeResult* result) {
|
||||
if (EagerTensor_CheckExact(arg)) {
|
||||
tensorflow::AbstractTensorHandleInterface* handle =
|
||||
tensorflow::ImmediateExecutionTensorHandle* handle =
|
||||
tensorflow::unwrap(EagerTensor_Handle(arg));
|
||||
|
||||
absl::StrAppend(&result->str, kDType,
|
||||
|
|
Loading…
Reference in New Issue