Split Abstract interfaces into Abstract and ImmediateExecution interfaces.

The Abstract interfaces are shared with tracing mode.
Introduce an AbstractFunction which handles the conversion between MLIR function and FunctionDef and the runtime can query whichever representation is suitable. Right now this only supports GetFunctionDef but an API for fetching the MLIR function directly will be added in future changes.

PiperOrigin-RevId: 316942774
Change-Id: I1abebbe853b98dd0048bab9fc092252f4caf3d1b
This commit is contained in:
Saurabh Saxena 2020-06-17 12:35:43 -07:00 committed by TensorFlower Gardener
parent f24487e619
commit 8d5171bad7
35 changed files with 489 additions and 197 deletions

View File

@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
const int num_inputs = input_shapes->num_items; const int num_inputs = input_shapes->num_items;
NodeDef node_def; NodeDef node_def;
tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op); tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op);
node_def.set_name(op->Name()); node_def.set_name(op->Name());
node_def.set_op(op->Name()); node_def.set_op(op->Name());
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {

View File

@ -38,9 +38,10 @@ tf_cuda_library(
"//tensorflow/core:portable_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",
], ],
"//conditions:default": [ "//conditions:default": [
":context_interface", ":immediate_execution_context",
":operation_interface", ":immediate_execution_operation",
":tensor_handle_interface", ":immediate_execution_tensor_handle",
":abstract_tensor_handle",
":tfe_context_internal", ":tfe_context_internal",
":tfe_cancellation_manager_internal", ":tfe_cancellation_manager_internal",
":tfe_executor_internal", ":tfe_executor_internal",
@ -101,13 +102,17 @@ tf_cuda_library(
filegroup( filegroup(
name = "pywrap_required_hdrs", name = "pywrap_required_hdrs",
srcs = [ srcs = [
"abstract_context.h",
"abstract_function.h",
"abstract_operation.h",
"abstract_tensor_handle.h",
"c_api_experimental.h", "c_api_experimental.h",
"c_api_internal.h", "c_api_internal.h",
"c_api_unified_experimental.h", "c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h", "dlpack.h",
"operation_interface.h", "immediate_execution_context.h",
"tensor_handle_interface.h", "immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"tfe_cancellation_manager_internal.h", "tfe_cancellation_manager_internal.h",
"tfe_executor_internal.h", "tfe_executor_internal.h",
"tfe_monitoring_internal.h", "tfe_monitoring_internal.h",
@ -163,12 +168,22 @@ cc_library(
) )
cc_library( cc_library(
name = "tensor_handle_interface", name = "abstract_tensor_handle",
hdrs = ["tensor_handle_interface.h"], hdrs = ["abstract_tensor_handle.h"],
visibility = [
"//tensorflow:internal",
],
deps = [],
)
cc_library(
name = "immediate_execution_tensor_handle",
hdrs = ["immediate_execution_tensor_handle.h"],
visibility = [ visibility = [
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":abstract_tensor_handle",
"//tensorflow/c:tensor_interface", "//tensorflow/c:tensor_interface",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -177,13 +192,13 @@ cc_library(
) )
cc_library( cc_library(
name = "operation_interface", name = "abstract_operation",
hdrs = ["operation_interface.h"], hdrs = ["abstract_operation.h"],
visibility = [ visibility = [
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":tensor_handle_interface", ":abstract_tensor_handle",
"//tensorflow/c:tensor_interface", "//tensorflow/c:tensor_interface",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -193,14 +208,58 @@ cc_library(
) )
cc_library( cc_library(
name = "context_interface", name = "immediate_execution_operation",
hdrs = ["context_interface.h"], hdrs = ["immediate_execution_operation.h"],
visibility = [ visibility = [
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":operation_interface", ":abstract_operation",
":tensor_handle_interface", ":abstract_tensor_handle",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_context",
hdrs = ["abstract_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_function",
":abstract_operation",
],
)
cc_library(
name = "abstract_function",
hdrs = ["abstract_function.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
],
)
cc_library(
name = "immediate_execution_context",
hdrs = ["immediate_execution_context.h"],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":immediate_execution_operation",
":immediate_execution_tensor_handle",
"//tensorflow/c:tensor_interface", "//tensorflow/c:tensor_interface",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -217,7 +276,7 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":context_interface", ":immediate_execution_context",
"//tensorflow/c:conversion_macros", "//tensorflow/c:conversion_macros",
], ],
) )
@ -277,7 +336,7 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":operation_interface", ":immediate_execution_operation",
"//tensorflow/c:conversion_macros", "//tensorflow/c:conversion_macros",
], ],
) )
@ -300,7 +359,7 @@ cc_library(
"//tensorflow:internal", "//tensorflow:internal",
], ],
deps = [ deps = [
":tensor_handle_interface", ":immediate_execution_tensor_handle",
"//tensorflow/c:conversion_macros", "//tensorflow/c:conversion_macros",
], ],
) )
@ -480,6 +539,9 @@ tf_cuda_library(
":tfe_context_internal", ":tfe_context_internal",
":tfe_op_internal", ":tfe_op_internal",
":tfe_tensorhandle_internal", ":tfe_tensorhandle_internal",
":abstract_operation",
":abstract_context",
":abstract_tensor_handle",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal", "//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",

View File

@ -0,0 +1,69 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_
#include <vector>
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations creation within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state e.g. an execution
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kTracing, kImmediateExecution };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
public:
AbstractContextKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Creates an operation builder and ties it to this context.
// The returned object can be used for setting operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
virtual AbstractOperation* CreateOperation() = 0;
// Registers a function with this context, after this the function is
// available to be called/referenced by its name in this context.
virtual Status RegisterFunction(AbstractFunction*) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
private:
const AbstractContextKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_

View File

@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats e.g. FunctionDef and Mlir
// function.
class AbstractFunction {
protected:
enum AbstractFunctionKind { kGraphFunc, kMlirFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {}
public:
// Returns which subclass is this instance of.
AbstractFunctionKind getKind() const { return kind_; }
virtual ~AbstractFunction() = default;
// Returns the AbstractFunction as a FunctionDef.
virtual Status GetFunctionDef(FunctionDef**) = 0;
private:
const AbstractFunctionKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_

View File

@ -12,24 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ #ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ #define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow { namespace tensorflow {
// Abstract interface to an operation. // Abstract interface to an operation.
class AbstractOperationInterface { // This interface allows building and executing an operation in either
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kTracing, kImmediateExecution };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
public: public:
AbstractOperationKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object. // Release any underlying resources, including the interface object.
// //
// WARNING: The destructor of this class is marked as protected to disallow // WARNING: The destructor of this class is marked as protected to disallow
@ -38,7 +43,6 @@ class AbstractOperationInterface {
// clients MUST call Release() in order to destroy an instance of this class. // clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0; virtual void Release() = 0;
virtual void Clear() = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0; virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const string& Name() const = 0; virtual const string& Name() const = 0;
@ -66,12 +70,10 @@ class AbstractOperationInterface {
// existing and given constraints will be performed. // existing and given constraints will be performed.
virtual Status SetDeviceName(const char* name) = 0; virtual Status SetDeviceName(const char* name) = 0;
virtual Status AddInput(AbstractTensorHandleInterface* input) = 0; virtual Status AddInput(AbstractTensorHandle* input) = 0;
virtual Status AddInputList( virtual Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) = 0;
absl::Span<AbstractTensorHandleInterface*> inputs) = 0; virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
virtual Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
int* num_retvals) = 0; int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status SetAttrString(const char* attr_name, const char* data, virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0; size_t length) = 0;
@ -82,7 +84,7 @@ class AbstractOperationInterface {
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0; const int num_dims) = 0;
virtual Status SetAttrFunction(const char* attr_name, virtual Status SetAttrFunction(const char* attr_name,
const AbstractOperationInterface* value) = 0; const AbstractOperation* value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value, virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0; size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name, virtual Status SetAttrTensor(const char* attr_name,
@ -102,19 +104,12 @@ class AbstractOperationInterface {
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0; const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList( virtual Status SetAttrFunctionList(
const char* attr_name, const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
absl::Span<const AbstractOperationInterface*> values) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0; private:
virtual Status OutputLength(const char* output_name, int* length) = 0; const AbstractOperationKind kind_;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
protected:
virtual ~AbstractOperationInterface() {}
}; };
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ #endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
// execution mode.
class AbstractTensorHandle {
protected:
enum AbstractTensorHandleKind { kTracing, kImmediateExecution };
explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {}
virtual ~AbstractTensorHandle() {}
public:
AbstractTensorHandleKind getKind() const { return kind_; }
// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
private:
const AbstractTensorHandleKind kind_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_

View File

@ -21,6 +21,8 @@ limitations under the License.
#include <string> #include <string>
#include <vector> #include <vector>
#include "tensorflow/c/eager/abstract_tensor_handle.h"
// clang-format off // clang-format off
#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/platform.h"
// clang-format on // clang-format on
@ -31,8 +33,8 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) { TF_Status* status) {
tensorflow::AbstractOperationInterface* new_op = tensorflow::ImmediateExecutionOperation* new_op =
tensorflow::unwrap(ctx)->CreateOperation(); tensorflow::unwrap(ctx)->CreateOperation();
status->status = new_op->Reset(op_or_function_name, nullptr); status->status = new_op->Reset(op_or_function_name, nullptr);
if (!status->status.ok()) { if (!status->status.ok()) {
@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) { TF_Status* status) {
status->status = tensorflow::unwrap(op)->AddInputList( status->status = tensorflow::unwrap(op)->AddInputList(
{tensorflow::unwrap(inputs), static_cast<size_t>(num_inputs)}); {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(inputs)),
static_cast<size_t>(num_inputs)});
} }
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) { const TFE_Op** value, int num_values) {
auto s = tensorflow::unwrap(op)->SetAttrFunctionList( auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
attr_name, {tensorflow::unwrap(value), static_cast<size_t>(num_values)}); attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
tensorflow::unwrap(value)),
static_cast<size_t>(num_values)});
if (!s.ok()) { if (!s.ok()) {
LOG(WARNING) << "Unable to set attribute: " << attr_name; LOG(WARNING) << "Unable to set attribute: " << attr_name;
} }
@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) { TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute( status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals); absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(retvals)),
*num_retvals),
num_retvals);
} }
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,

View File

@ -38,7 +38,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) { const char* raw_device_name, TF_Status* status) {
if (op_to_reset) { if (op_to_reset) {
tensorflow::AbstractOperationInterface* op = tensorflow::ImmediateExecutionOperation* op =
tensorflow::unwrap(op_to_reset); tensorflow::unwrap(op_to_reset);
op->Clear(); op->Clear();
status->status = op->Reset(op_or_function_name, raw_device_name); status->status = op->Reset(op_or_function_name, raw_device_name);

View File

@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <vector> #include <vector>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/numeric_types.h"
@ -34,16 +35,9 @@ namespace tensorflow {
// //
// A context is responsible for creating key objects such as Tensors, // A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations. // TensorHandles & Operations.
class AbstractContextInterface { class ImmediateExecutionContext : public AbstractContext {
public: public:
// Release any underlying resources, including the interface object. static constexpr AbstractContextKind kKind = kImmediateExecution;
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
virtual void Release() = 0;
// Optimized scalar creation functions // Optimized scalar creation functions
virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0; virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
@ -74,15 +68,15 @@ class AbstractContextInterface {
void* memory_releaser_arg) = 0; void* memory_releaser_arg) = 0;
// Create a handle to wrap and manage a Tensor // Create a handle to wrap and manage a Tensor
virtual AbstractTensorHandleInterface* CreateLocalHandle( virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
AbstractTensorInterface* t) = 0; AbstractTensorInterface* t) = 0;
// Copy the handle to another device. // Copy the handle to another device.
virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice( virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name, ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) = 0; Status* status) = 0;
// Create an operation to perform op execution // Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0; ImmediateExecutionOperation* CreateOperation() override = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager // Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent // Runtime. This is necessary to decouple runtime-dependent
@ -107,14 +101,12 @@ class AbstractContextInterface {
// be executed as an op. Return error if the function with the same name // be executed as an op. Return error if the function with the same name
// already exists. // already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected: protected:
virtual ~AbstractContextInterface() {} ImmediateExecutionContext() : AbstractContext(kKind) {}
~ImmediateExecutionContext() override {}
}; };
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ #endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
struct TFE_Op;
namespace tensorflow {
// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
static constexpr AbstractOperationKind kKind = kImmediateExecution;
virtual void Clear() = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual Status SetUseXla(bool enable) = 0;
protected:
ImmediateExecutionOperation() : AbstractOperation(kKind) {}
~ImmediateExecutionOperation() override {}
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_

View File

@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ #ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ #define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
@ -30,15 +31,9 @@ namespace tensorflow {
// files. The interface lists the common functionality that must be provided by // files. The interface lists the common functionality that must be provided by
// any concrete implementation. However, in cases where the true concrete class // any concrete implementation. However, in cases where the true concrete class
// is needed a static_cast can be applied. // is needed a static_cast can be applied.
class AbstractTensorHandleInterface { class ImmediateExecutionTensorHandle : public AbstractTensorHandle {
public: public:
// Release any underlying resources, including the interface object. static constexpr AbstractTensorHandleKind kKind = kImmediateExecution;
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus this must be allocated on the heap and
// clients MUST call Release() in order to destroy an instance of this class.
virtual void Release() = 0;
// Returns tensor dtype. // Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0; virtual tensorflow::DataType DataType() const = 0;
@ -57,12 +52,13 @@ class AbstractTensorHandleInterface {
virtual AbstractTensorInterface* Resolve(Status* status) = 0; virtual AbstractTensorInterface* Resolve(Status* status) = 0;
// Return a copy of the handle. // Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0; virtual ImmediateExecutionTensorHandle* Copy() = 0;
protected: protected:
virtual ~AbstractTensorHandleInterface() {} ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {}
~ImmediateExecutionTensorHandle() override {}
}; };
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ #endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/immediate_execution_context.h"
// Wraps a pointer to a context implementation. // Wraps a pointer to a context implementation.
// //
@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context;
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context); DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context);
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
// Wraps a pointer to an operation implementation. // Wraps a pointer to an operation implementation.
// //
@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op;
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op); DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*); DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*);
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Wraps a pointer to a tensor handle implementation. // Wraps a pointer to a tensor handle implementation.
// //
@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle;
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface, DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle,
TFE_TensorHandle); TFE_TensorHandle);
DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*, DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*,
TFE_TensorHandle*); TFE_TensorHandle*);
} // namespace tensorflow } // namespace tensorflow

View File

@ -23,8 +23,8 @@ cc_library(
], ],
deps = [ deps = [
":function_metadata", ":function_metadata",
"//tensorflow/c/eager:operation_interface", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
], ],
) )

View File

@ -15,12 +15,12 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
namespace tensorflow { namespace tensorflow {
const std::vector<tensorflow::AbstractTensorHandleInterface*>& const std::vector<tensorflow::ImmediateExecutionTensorHandle*>&
ConcreteFunction::GetCaptures() const { ConcreteFunction::GetCaptures() const {
return captures_; return captures_;
} }

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function.pb.h"
@ -38,15 +38,15 @@ class ConcreteFunction {
virtual ~ConcreteFunction() = 0; virtual ~ConcreteFunction() = 0;
// This method returns the "Call" Op used to execute the function. // This method returns the "Call" Op used to execute the function.
virtual AbstractOperationInterface* GetCallOp() = 0; virtual ImmediateExecutionOperation* GetCallOp() = 0;
const std::vector<tensorflow::AbstractTensorHandleInterface*>& GetCaptures() const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
const; const;
const FunctionMetadata& GetFunctionMetadata() const; const FunctionMetadata& GetFunctionMetadata() const;
private: private:
FunctionMetadata metadata_; FunctionMetadata metadata_;
std::vector<tensorflow::AbstractTensorHandleInterface*> captures_; std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
FunctionDef* function_; FunctionDef* function_;
}; };

View File

@ -20,7 +20,7 @@ cc_library(
"owned_eager_op.h", "owned_eager_op.h",
], ],
deps = [ deps = [
"//tensorflow/c/eager:operation_interface", "//tensorflow/c/eager:immediate_execution_operation",
], ],
) )
@ -30,7 +30,7 @@ cc_library(
"owned_tensor_handle.h", "owned_tensor_handle.h",
], ],
deps = [ deps = [
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:tensor_handle",
], ],
) )
@ -39,7 +39,7 @@ cc_library(
name = "owned_eager_context", name = "owned_eager_context",
hdrs = ["owned_eager_context.h"], hdrs = ["owned_eager_context.h"],
deps = [ deps = [
"//tensorflow/c/eager:context_interface", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context",
], ],
) )
@ -63,8 +63,9 @@ cc_library(
deps = [ deps = [
":owned_eager_op", ":owned_eager_op",
":owned_tensor_handle", ":owned_tensor_handle",
"//tensorflow/c/eager:context_interface", "//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",

View File

@ -18,14 +18,14 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
namespace tensorflow { namespace tensorflow {
namespace internal { namespace internal {
struct AbstractContextInterfaceDeleter { struct ImmediateExecutionContextDeleter {
void operator()(AbstractContextInterface* p) const { void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) { if (p != nullptr) {
p->Release(); p->Release();
} }
@ -43,8 +43,8 @@ struct EagerContextDeleter {
} // namespace internal } // namespace internal
using AbstractContextPtr = using AbstractContextPtr =
std::unique_ptr<AbstractContextInterface, std::unique_ptr<ImmediateExecutionContext,
internal::AbstractContextInterfaceDeleter>; internal::ImmediateExecutionContextDeleter>;
using EagerContextPtr = using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>; std::unique_ptr<EagerContext, internal::EagerContextDeleter>;

View File

@ -18,13 +18,13 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/c/eager/operation_interface.h" #include "tensorflow/c/eager/immediate_execution_operation.h"
namespace tensorflow { namespace tensorflow {
namespace internal { namespace internal {
struct AbstractOperationInterfaceDeleter { struct ImmediateExecutionOperationDeleter {
void operator()(AbstractOperationInterface* p) const { void operator()(ImmediateExecutionOperation* p) const {
if (p != nullptr) { if (p != nullptr) {
p->Release(); p->Release();
} }
@ -34,8 +34,8 @@ struct AbstractOperationInterfaceDeleter {
} // namespace internal } // namespace internal
using AbstractOpPtr = using AbstractOpPtr =
std::unique_ptr<AbstractOperationInterface, std::unique_ptr<ImmediateExecutionOperation,
internal::AbstractOperationInterfaceDeleter>; internal::ImmediateExecutionOperationDeleter>;
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow { namespace tensorflow {
@ -33,7 +33,7 @@ struct TensorHandleDeleter {
}; };
struct AbstractTensorHandleDeleter { struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandleInterface* p) const { void operator()(ImmediateExecutionTensorHandle* p) const {
if (p != nullptr) { if (p != nullptr) {
p->Release(); p->Release();
} }
@ -46,7 +46,7 @@ using TensorHandlePtr =
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>; std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
using AbstractTensorHandlePtr = using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandleInterface, std::unique_ptr<ImmediateExecutionTensorHandle,
internal::AbstractTensorHandleDeleter>; internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,7 +16,8 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
@ -32,7 +33,7 @@ namespace internal {
static const char kNoSharingResourceID[] = static const char kNoSharingResourceID[] =
"cd2c89b7-88b7-44c8-ad83-06c2a9158347"; "cd2c89b7-88b7-44c8-ad83-06c2a9158347";
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape, DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle) { AbstractTensorHandlePtr* handle) {
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation()); AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
@ -50,17 +51,20 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString( TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID))); "shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
AbstractTensorHandleInterface* var_handle = nullptr; AbstractTensorHandle* var_handle = nullptr;
int num_retvals = 1; int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute( TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals)); absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
handle->reset(var_handle); if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
handle->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(var_handle));
return Status(); return Status();
} }
Status AssignVariable(AbstractContextInterface* ctx, Status AssignVariable(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* variable_handle, ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value) { DataType dtype, ImmediateExecutionTensorHandle* value) {
AbstractOpPtr assign_op(ctx->CreateOperation()); AbstractOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr)); TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype)); TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
@ -72,24 +76,27 @@ Status AssignVariable(AbstractContextInterface* ctx,
return Status(); return Status();
} }
Status ReadVariable(AbstractContextInterface* ctx, Status ReadVariable(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* variable_handle, ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output) { DataType dtype, AbstractTensorHandlePtr* output) {
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation()); AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr)); TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype)); TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle)); TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
AbstractTensorHandleInterface* value = nullptr; AbstractTensorHandle* value = nullptr;
int num_retvals = 1; int num_retvals = 1;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals)); read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
output->reset(value); if (value->getKind() != ImmediateExecutionTensorHandle::kKind) {
return errors::Internal("Unexpected tensor handle kind.");
}
output->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(value));
return Status(); return Status();
} }
Status DestroyResource(AbstractContextInterface* ctx, Status DestroyResource(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* handle) { ImmediateExecutionTensorHandle* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation()); AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr)); TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true)); TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
@ -30,7 +30,7 @@ namespace internal {
// TensorHandle associated with the variable. This is equivalent to creating an // TensorHandle associated with the variable. This is equivalent to creating an
// unitialized TF2 tf.Variable. // unitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx,
DataType dtype, TensorShape shape, DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle); AbstractTensorHandlePtr* handle);
@ -39,22 +39,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
// underlying variable for `variable_handle`. Note that it is illegal to assign // underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was // a variable to a Tensor with a different dtype than what the variable was
// created with. // created with.
Status AssignVariable(AbstractContextInterface* ctx, Status AssignVariable(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* variable_handle, ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value); DataType dtype, ImmediateExecutionTensorHandle* value);
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable // Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be // value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`. // the dtype of the variable associated with `variable_handle`.
Status ReadVariable(AbstractContextInterface* ctx, Status ReadVariable(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* variable_handle, ImmediateExecutionTensorHandle* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output); DataType dtype, AbstractTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to // Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter: // the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290 // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(AbstractContextInterface* ctx, Status DestroyResource(ImmediateExecutionContext* ctx,
AbstractTensorHandleInterface* handle); ImmediateExecutionTensorHandle* handle);
} // namespace internal } // namespace internal
} // namespace tensorflow } // namespace tensorflow

View File

@ -178,7 +178,7 @@ cc_library(
":tensorhandle_list_type", ":tensorhandle_list_type",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/c/eager:tfe_tensorhandle_internal",
], ],
) )
@ -190,7 +190,7 @@ cc_library(
], ],
deps = [ deps = [
"//tensorflow/c:conversion_macros", "//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
], ],
) )

View File

@ -17,7 +17,7 @@ limitations under the License.
#include <stddef.h> #include <stddef.h>
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/c/conversion_macros.h" #include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
// Internal structures used by the SavedModel C API. These are likely to // Internal structures used by the SavedModel C API. These are likely to
// change and should not be depended on. // change and should not be depended on.
@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList;
namespace tensorflow { namespace tensorflow {
DEFINE_CONVERSION_FUNCTIONS( DEFINE_CONVERSION_FUNCTIONS(
std::vector<tensorflow::AbstractTensorHandleInterface*>, std::vector<tensorflow::ImmediateExecutionTensorHandle*>,
TF_TensorHandleList) TF_TensorHandleList)
} // namespace tensorflow } // namespace tensorflow

View File

@ -32,6 +32,8 @@ tf_cuda_library(
":tensor_handle", ":tensor_handle",
"//tensorflow/c:c_api_internal", "//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/core/platform:errors",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -74,9 +76,9 @@ tf_cuda_library(
":kernel_and_device", ":kernel_and_device",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:context_interface", "//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:operation_interface", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_env",
] + select({ ] + select({
@ -137,8 +139,10 @@ tf_cuda_library(
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant", "@com_google_absl//absl/types:variant",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:operation_interface", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -211,7 +215,7 @@ tf_cuda_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant", "@com_google_absl//absl/types:variant",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/c/eager:tensor_handle_interface", "//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/core:core_cpu_lib", "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
@ -496,6 +500,8 @@ cc_library(
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
"//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:common",
"//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/core/platform:errors",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite", "//tensorflow/core:portable_tensorflow_lib_lite",

View File

@ -30,8 +30,6 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h" #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/colocation_graph.h" #include "tensorflow/core/common_runtime/colocation_graph.h"

View File

@ -33,7 +33,7 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/context_interface.h" #include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
@ -135,7 +135,7 @@ class CustomDevice {
// TensorHandles may be placed either on custom or physical devices. // TensorHandles may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>; using VariantDevice = absl::variant<Device*, CustomDevice*>;
class EagerContext : public AbstractContextInterface, public core::RefCounted { class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
public: public:
static constexpr uint64 kInvalidContextId = 0; static constexpr uint64 kInvalidContextId = 0;
@ -178,12 +178,14 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
MemoryReleaser memory_releaser, MemoryReleaser memory_releaser,
void* memory_releaser_arg) override; void* memory_releaser_arg) override;
AbstractTensorHandleInterface* CreateLocalHandle( ImmediateExecutionTensorHandle* CreateLocalHandle(
AbstractTensorInterface* t) override; AbstractTensorInterface* t) override;
AbstractTensorHandleInterface* CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name, ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) override; Status* status) override;
AbstractOperationInterface* CreateOperation() override; ImmediateExecutionOperation* CreateOperation() override;
Status RegisterFunction(AbstractFunction* f) override;
bool UsesTFRT() override; bool UsesTFRT() override;
@ -716,7 +718,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
std::function<void()> resource_deallocator_ = nullptr; std::function<void()> resource_deallocator_ = nullptr;
}; };
inline EagerContext* ContextFromInterface(AbstractContextInterface* context) { inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) {
return down_cast<EagerContext*>(context); return down_cast<EagerContext*>(context);
} }

View File

@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/errors.h"
namespace { namespace {
@ -112,8 +114,8 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
} }
} }
AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
AbstractTensorHandleInterface* handle, const char* device_name, ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) { Status* status) {
TensorHandle* input = TensorHandleFromInterface(handle); TensorHandle* input = TensorHandleFromInterface(handle);
TensorHandle* result = nullptr; TensorHandle* result = nullptr;
@ -158,7 +160,7 @@ AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice(
// here to a circular BUILD dep issue. If we move this to context.cc, then we // here to a circular BUILD dep issue. If we move this to context.cc, then we
// will have the circular dependency of: // will have the circular dependency of:
// context -> tensor_handle -> remote_tensor_handle_data -> context // context -> tensor_handle -> remote_tensor_handle_data -> context
AbstractTensorHandleInterface* EagerContext::CreateLocalHandle( ImmediateExecutionTensorHandle* EagerContext::CreateLocalHandle(
AbstractTensorInterface* t) { AbstractTensorInterface* t) {
Tensor tensor = TensorFromInterface(t); Tensor tensor = TensorFromInterface(t);
return TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/HostCPU(), return TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/HostCPU(),
@ -168,14 +170,23 @@ AbstractTensorHandleInterface* EagerContext::CreateLocalHandle(
// TODO(b/152902651): We have to keep this function here since EagerOperation // TODO(b/152902651): We have to keep this function here since EagerOperation
// depends on EagerContext. Thus, the context build target can't depend on // depends on EagerContext. Thus, the context build target can't depend on
// EagerOperation. // EagerOperation.
AbstractOperationInterface* EagerContext::CreateOperation() { ImmediateExecutionOperation* EagerContext::CreateOperation() {
return new EagerOperation(this); return new EagerOperation(this);
} }
Status EagerContext::RegisterFunction(AbstractFunction* f) {
FunctionDef* fdef;
TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef));
if (!fdef) {
return errors::InvalidArgument("GetFunctionDef returned nullptr.");
}
return AddFunctionDef(*fdef);
}
// TODO(b/152902651): Once we move many execute.cc functions into // TODO(b/152902651): Once we move many execute.cc functions into
// eager_operation.cc we can avoid a circular dependency between them. // eager_operation.cc we can avoid a circular dependency between them.
Status EagerOperation::Execute( Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
absl::Span<AbstractTensorHandleInterface*> retvals, int* num_retvals) { int* num_retvals) {
return EagerExecute( return EagerExecute(
this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()), this, reinterpret_cast<tensorflow::TensorHandle**>(retvals.data()),
num_retvals); num_retvals);

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
@ -91,8 +93,8 @@ Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
return Status::OK(); return Status::OK();
} }
Status EagerOperation::SetAttrFunction( Status EagerOperation::SetAttrFunction(const char* attr_name,
const char* attr_name, const AbstractOperationInterface* value) { const AbstractOperation* value) {
AttrValue attr_value; AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func(); NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name()); func->set_name(value->Name());
@ -194,8 +196,7 @@ Status EagerOperation::SetAttrShapeList(const char* attr_name,
} }
Status EagerOperation::SetAttrFunctionList( Status EagerOperation::SetAttrFunctionList(
const char* attr_name, const char* attr_name, absl::Span<const AbstractOperation*> values) {
absl::Span<const AbstractOperationInterface*> values) {
size_t num_values = values.size(); size_t num_values = values.size();
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]); std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) { for (int i = 0; i < num_values; i++) {
@ -253,14 +254,13 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
return Status::OK(); return Status::OK();
} }
Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) { Status EagerOperation::AddInput(AbstractTensorHandle* input) {
TensorHandle* h = TensorHandleFromInterface(input); TensorHandle* h = TensorHandleFromInterface(input);
AddTensorHandle(h); AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h); return MaybeInferSingleInputAttrs(h);
} }
Status EagerOperation::AddInputList( Status EagerOperation::AddInputList(absl::Span<AbstractTensorHandle*> inputs) {
absl::Span<AbstractTensorHandleInterface*> inputs) {
for (auto& input : inputs) { for (auto& input : inputs) {
TensorHandle* h = TensorHandleFromInterface(input); TensorHandle* h = TensorHandleFromInterface(input);
AddTensorHandle(h); AddTensorHandle(h);

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "absl/types/variant.h" #include "absl/types/variant.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h"
@ -31,7 +32,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
class EagerOperation : public AbstractOperationInterface { class EagerOperation : public ImmediateExecutionOperation {
public: public:
explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {} explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {}
~EagerOperation() override { ~EagerOperation() override {
@ -56,7 +57,7 @@ class EagerOperation : public AbstractOperationInterface {
} }
// Replaces the previous device name with the given one (see // Replaces the previous device name with the given one (see
// AbstractOperationInterface::SetDeviceName for more details). // AbstractOperation::SetDeviceName for more details).
// //
// This also resets the internal device pointer, unless the given name refers // This also resets the internal device pointer, unless the given name refers
// to a known custom device, in which case the internal device pointer is // to a known custom device, in which case the internal device pointer is
@ -76,10 +77,9 @@ class EagerOperation : public AbstractOperationInterface {
Status SetAttrValue(const char* attr_name, const AttrValue& value); Status SetAttrValue(const char* attr_name, const AttrValue& value);
Status AddInput(AbstractTensorHandleInterface* input) override; Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList( Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override;
absl::Span<AbstractTensorHandleInterface*> inputs) override; Status Execute(absl::Span<AbstractTensorHandle*> retvals,
Status Execute(absl::Span<AbstractTensorHandleInterface*> retvals,
int* num_retvals) override; int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override { return op_def_; }; const tensorflow::OpDef* OpDef() const override { return op_def_; };
@ -92,7 +92,7 @@ class EagerOperation : public AbstractOperationInterface {
Status SetAttrShape(const char* attr_name, const int64_t* dims, Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override; const int num_dims) override;
Status SetAttrFunction(const char* attr_name, Status SetAttrFunction(const char* attr_name,
const AbstractOperationInterface* value) override; const AbstractOperation* value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data, Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override; size_t length) override;
Status SetAttrTensor(const char* attr_name, Status SetAttrTensor(const char* attr_name,
@ -111,7 +111,7 @@ class EagerOperation : public AbstractOperationInterface {
const int* num_dims, int num_values) override; const int* num_dims, int num_values) override;
Status SetAttrFunctionList( Status SetAttrFunctionList(
const char* attr_name, const char* attr_name,
absl::Span<const AbstractOperationInterface*> values) override; absl::Span<const AbstractOperation*> values) override;
Status InputLength(const char* input_name, int* length) override; Status InputLength(const char* input_name, int* length) override;
Status OutputLength(const char* output_name, int* length) override; Status OutputLength(const char* output_name, int* length) override;
@ -235,7 +235,7 @@ inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
} }
inline EagerOperation* OperationFromInterface( inline EagerOperation* OperationFromInterface(
AbstractOperationInterface* operation) { ImmediateExecutionOperation* operation) {
return down_cast<EagerOperation*>(operation); return down_cast<EagerOperation*>(operation);
} }

View File

@ -1071,7 +1071,7 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
} }
} }
tensorflow::AbstractTensorHandleInterface* TensorHandle::Copy() { tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
Ref(); Ref();
return this; return this;
} }

View File

@ -31,7 +31,7 @@ limitations under the License.
// clang-format on // clang-format on
#include "absl/types/variant.h" #include "absl/types/variant.h"
#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
@ -53,7 +53,7 @@ class EagerContext;
// Associates a Tensor and a Device, used in the eager runtime. Internal version // Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class // of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle). // (unrelated to python TensorHandle).
class TensorHandle : public AbstractTensorHandleInterface, class TensorHandle : public ImmediateExecutionTensorHandle,
public core::RefCounted { public core::RefCounted {
// TensorHandle for dtype != DT_RESOURCE // TensorHandle for dtype != DT_RESOURCE
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
@ -121,7 +121,7 @@ class TensorHandle : public AbstractTensorHandleInterface,
const char* BackingDeviceName(Status* status) const override; const char* BackingDeviceName(Status* status) const override;
AbstractTensorInterface* Resolve(Status* status) override; AbstractTensorInterface* Resolve(Status* status) override;
AbstractTensorHandleInterface* Copy() override; ImmediateExecutionTensorHandle* Copy() override;
// Return the Tensor from the default device. // Return the Tensor from the default device.
Status Tensor(const tensorflow::Tensor** t) const; Status Tensor(const tensorflow::Tensor** t) const;
@ -372,12 +372,12 @@ const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
// Returns the device backing the resource. Else, returns nullptr. // Returns the device backing the resource. Else, returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
class TensorHandleInterface : public AbstractTensorHandleInterface { class TensorHandleInterface : public ImmediateExecutionTensorHandle {
public: public:
}; };
inline TensorHandle* TensorHandleFromInterface( template <typename T>
AbstractTensorHandleInterface* handle) { inline TensorHandle* TensorHandleFromInterface(T* handle) {
return down_cast<TensorHandle*>(handle); return down_cast<TensorHandle*>(handle);
} }

View File

@ -2008,7 +2008,7 @@ bool ListContainsNone(PyObject* list) {
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) { if (EagerTensor_CheckExact(tensor)) {
tensorflow::AbstractTensorHandleInterface* handle = tensorflow::ImmediateExecutionTensorHandle* handle =
tensorflow::unwrap(EagerTensor_Handle(tensor)); tensorflow::unwrap(EagerTensor_Handle(tensor));
tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype = tensorflow::DataType dtype =
@ -3869,7 +3869,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
bool include_tensor_ranks_only, bool include_tensor_ranks_only,
EncodeResult* result) { EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) { if (EagerTensor_CheckExact(arg)) {
tensorflow::AbstractTensorHandleInterface* handle = tensorflow::ImmediateExecutionTensorHandle* handle =
tensorflow::unwrap(EagerTensor_Handle(arg)); tensorflow::unwrap(EagerTensor_Handle(arg));
absl::StrAppend(&result->str, kDType, absl::StrAppend(&result->str, kDType,