diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index e9e6d470c68..831c6a0ad40 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, const int num_inputs = input_shapes->num_items; NodeDef node_def; - tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op); + tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op); node_def.set_name(op->Name()); node_def.set_op(op->Name()); for (int i = 0; i < num_inputs; ++i) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 9d3c79e0ae7..5f7ab4a1f59 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -38,9 +38,10 @@ tf_cuda_library( "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ - ":context_interface", - ":operation_interface", - ":tensor_handle_interface", + ":immediate_execution_context", + ":immediate_execution_operation", + ":immediate_execution_tensor_handle", + ":abstract_tensor_handle", ":tfe_context_internal", ":tfe_cancellation_manager_internal", ":tfe_executor_internal", @@ -101,13 +102,17 @@ tf_cuda_library( filegroup( name = "pywrap_required_hdrs", srcs = [ + "abstract_context.h", + "abstract_function.h", + "abstract_operation.h", + "abstract_tensor_handle.h", "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", - "context_interface.h", "dlpack.h", - "operation_interface.h", - "tensor_handle_interface.h", + "immediate_execution_context.h", + "immediate_execution_operation.h", + "immediate_execution_tensor_handle.h", "tfe_cancellation_manager_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", @@ -163,12 +168,22 @@ cc_library( ) cc_library( - name = "tensor_handle_interface", - hdrs = ["tensor_handle_interface.h"], + name = "abstract_tensor_handle", + hdrs = ["abstract_tensor_handle.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [], +) + +cc_library( + name = "immediate_execution_tensor_handle", + hdrs = ["immediate_execution_tensor_handle.h"], visibility = [ "//tensorflow:internal", ], deps = [ + ":abstract_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -177,13 +192,13 @@ cc_library( ) cc_library( - name = "operation_interface", - hdrs = ["operation_interface.h"], + name = "abstract_operation", + hdrs = ["abstract_operation.h"], visibility = [ "//tensorflow:internal", ], deps = [ - ":tensor_handle_interface", + ":abstract_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -193,14 +208,58 @@ cc_library( ) cc_library( - name = "context_interface", - hdrs = ["context_interface.h"], + name = "immediate_execution_operation", + hdrs = ["immediate_execution_operation.h"], visibility = [ "//tensorflow:internal", ], deps = [ - ":operation_interface", - ":tensor_handle_interface", + ":abstract_operation", + ":abstract_tensor_handle", + ":immediate_execution_tensor_handle", + "//tensorflow/c:tensor_interface", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "abstract_context", + hdrs = ["abstract_context.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_function", + ":abstract_operation", + ], +) + +cc_library( + name = "abstract_function", + hdrs = ["abstract_function.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + ], +) + +cc_library( + name = "immediate_execution_context", + hdrs = ["immediate_execution_context.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":immediate_execution_operation", + ":immediate_execution_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -217,7 +276,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":context_interface", + ":immediate_execution_context", "//tensorflow/c:conversion_macros", ], ) @@ -277,7 +336,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":operation_interface", + ":immediate_execution_operation", "//tensorflow/c:conversion_macros", ], ) @@ -300,7 +359,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":tensor_handle_interface", + ":immediate_execution_tensor_handle", "//tensorflow/c:conversion_macros", ], ) @@ -480,6 +539,9 @@ tf_cuda_library( ":tfe_context_internal", ":tfe_op_internal", ":tfe_tensorhandle_internal", + ":abstract_operation", + ":abstract_context", + ":abstract_tensor_handle", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h new file mode 100644 index 00000000000..59c726349ac --- /dev/null +++ b/tensorflow/c/eager/abstract_context.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ + +#include + +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { + +// Abstract interface to a context. +// +// This serves as a factory for creating `AbstractOperation`s and for +// registering traced functions. +// Operations creation within a context can only be executed in that context +// (for now at least). +// Implementations of the context may contain some state e.g. an execution +// environment, a traced representation etc. +class AbstractContext { + protected: + enum AbstractContextKind { kTracing, kImmediateExecution }; + explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} + virtual ~AbstractContext() {} + + public: + AbstractContextKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus clients MUST call Release() in order to + // destroy an instance of this class. + virtual void Release() = 0; + + // Creates an operation builder and ties it to this context. + // The returned object can be used for setting operation's attributes, + // adding inputs and finally executing (immediately or lazily as in tracing) + // it in this context. + virtual AbstractOperation* CreateOperation() = 0; + + // Registers a function with this context, after this the function is + // available to be called/referenced by its name in this context. + virtual Status RegisterFunction(AbstractFunction*) = 0; + // Remove a function. 'func' argument is the name of a previously added + // FunctionDef. The name is in fdef.signature.name. + virtual Status RemoveFunction(const string& func) = 0; + + private: + const AbstractContextKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ diff --git a/tensorflow/c/eager/abstract_function.h b/tensorflow/c/eager/abstract_function.h new file mode 100644 index 00000000000..e322b31f2b4 --- /dev/null +++ b/tensorflow/c/eager/abstract_function.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// A traced function: this hides the complexity of converting the serialized +// representation between various supported formats e.g. FunctionDef and Mlir +// function. +class AbstractFunction { + protected: + enum AbstractFunctionKind { kGraphFunc, kMlirFunc }; + explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractFunctionKind getKind() const { return kind_; } + virtual ~AbstractFunction() = default; + + // Returns the AbstractFunction as a FunctionDef. + virtual Status GetFunctionDef(FunctionDef**) = 0; + + private: + const AbstractFunctionKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/abstract_operation.h similarity index 80% rename from tensorflow/c/eager/operation_interface.h rename to tensorflow/c/eager/abstract_operation.h index 844ba6c14bd..da4b6ecb75e 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -12,24 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ #include "absl/types/span.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" -struct TFE_Op; - namespace tensorflow { // Abstract interface to an operation. -class AbstractOperationInterface { +// This interface allows building and executing an operation in either +// tracing or immediate execution mode. +class AbstractOperation { + protected: + enum AbstractOperationKind { kTracing, kImmediateExecution }; + explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} + virtual ~AbstractOperation() {} + public: + AbstractOperationKind getKind() const { return kind_; } + // Release any underlying resources, including the interface object. // // WARNING: The destructor of this class is marked as protected to disallow @@ -38,7 +43,6 @@ class AbstractOperationInterface { // clients MUST call Release() in order to destroy an instance of this class. virtual void Release() = 0; - virtual void Clear() = 0; virtual Status Reset(const char* op, const char* raw_device_name) = 0; virtual const string& Name() const = 0; @@ -66,12 +70,10 @@ class AbstractOperationInterface { // existing and given constraints will be performed. virtual Status SetDeviceName(const char* name) = 0; - virtual Status AddInput(AbstractTensorHandleInterface* input) = 0; - virtual Status AddInputList( - absl::Span inputs) = 0; - virtual Status Execute(absl::Span retvals, + virtual Status AddInput(AbstractTensorHandle* input) = 0; + virtual Status AddInputList(absl::Span inputs) = 0; + virtual Status Execute(absl::Span retvals, int* num_retvals) = 0; - virtual const tensorflow::OpDef* OpDef() const = 0; virtual Status SetAttrString(const char* attr_name, const char* data, size_t length) = 0; @@ -82,7 +84,7 @@ class AbstractOperationInterface { virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) = 0; virtual Status SetAttrFunction(const char* attr_name, - const AbstractOperationInterface* value) = 0; + const AbstractOperation* value) = 0; virtual Status SetAttrFunctionName(const char* attr_name, const char* value, size_t length) = 0; virtual Status SetAttrTensor(const char* attr_name, @@ -102,19 +104,12 @@ class AbstractOperationInterface { virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) = 0; virtual Status SetAttrFunctionList( - const char* attr_name, - absl::Span values) = 0; + const char* attr_name, absl::Span values) = 0; - virtual Status InputLength(const char* input_name, int* length) = 0; - virtual Status OutputLength(const char* output_name, int* length) = 0; - - // Experimental - virtual Status SetUseXla(bool enable) = 0; - - protected: - virtual ~AbstractOperationInterface() {} + private: + const AbstractOperationKind kind_; }; } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h new file mode 100644 index 00000000000..14acac29bb9 --- /dev/null +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ + +namespace tensorflow { + +// Abstract interface to a Tensor handle in either tracing or immediate +// execution mode. +class AbstractTensorHandle { + protected: + enum AbstractTensorHandleKind { kTracing, kImmediateExecution }; + explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} + virtual ~AbstractTensorHandle() {} + + public: + AbstractTensorHandleKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus this must be allocated on the heap and + // clients MUST call Release() in order to destroy an instance of this class. + virtual void Release() = 0; + + private: + const AbstractTensorHandleKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index fdc91675f8b..4be3cdd7c2d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "tensorflow/c/eager/abstract_tensor_handle.h" + // clang-format off #include "tensorflow/core/platform/platform.h" // clang-format on @@ -31,8 +33,8 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" @@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - tensorflow::AbstractOperationInterface* new_op = + tensorflow::ImmediateExecutionOperation* new_op = tensorflow::unwrap(ctx)->CreateOperation(); status->status = new_op->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { @@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { status->status = tensorflow::unwrap(op)->AddInputList( - {tensorflow::unwrap(inputs), static_cast(num_inputs)}); + {reinterpret_cast( + tensorflow::unwrap(inputs)), + static_cast(num_inputs)}); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, @@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, const TFE_Op** value, int num_values) { auto s = tensorflow::unwrap(op)->SetAttrFunctionList( - attr_name, {tensorflow::unwrap(value), static_cast(num_values)}); + attr_name, {reinterpret_cast( + tensorflow::unwrap(value)), + static_cast(num_values)}); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { status->status = tensorflow::unwrap(op)->Execute( - absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals); + absl::MakeSpan(reinterpret_cast( + tensorflow::unwrap(retvals)), + *num_retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 0d71b11531b..9937fd7551f 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -38,7 +38,7 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { - tensorflow::AbstractOperationInterface* op = + tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(op_to_reset); op->Clear(); status->status = op->Reset(op_or_function_name, raw_device_name); diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/immediate_execution_context.h similarity index 78% rename from tensorflow/c/eager/context_interface.h rename to tensorflow/c/eager/immediate_execution_context.h index e5a770a6826..0e3fe8cd4e1 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -12,15 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ #include #include "absl/types/optional.h" #include "absl/types/span.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/numeric_types.h" @@ -34,16 +35,9 @@ namespace tensorflow { // // A context is responsible for creating key objects such as Tensors, // TensorHandles & Operations. -class AbstractContextInterface { +class ImmediateExecutionContext : public AbstractContext { public: - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus clients MUST call Release() in order to - // destroy an instance of this class. - virtual void Release() = 0; - + static constexpr AbstractContextKind kKind = kImmediateExecution; // Optimized scalar creation functions virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0; virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; @@ -74,15 +68,15 @@ class AbstractContextInterface { void* memory_releaser_arg) = 0; // Create a handle to wrap and manage a Tensor - virtual AbstractTensorHandleInterface* CreateLocalHandle( + virtual ImmediateExecutionTensorHandle* CreateLocalHandle( AbstractTensorInterface* t) = 0; // Copy the handle to another device. - virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice( - AbstractTensorHandleInterface* handle, const char* device_name, + virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, Status* status) = 0; // Create an operation to perform op execution - virtual AbstractOperationInterface* CreateOperation() = 0; + ImmediateExecutionOperation* CreateOperation() override = 0; // Returns whether the runtime is backed by TFRT or the legacy TF Eager // Runtime. This is necessary to decouple runtime-dependent @@ -107,14 +101,12 @@ class AbstractContextInterface { // be executed as an op. Return error if the function with the same name // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; - // Remove a function. 'func' argument is the name of a previously added - // FunctionDef. The name is in fdef.signature.name. - virtual Status RemoveFunction(const string& func) = 0; protected: - virtual ~AbstractContextInterface() {} + ImmediateExecutionContext() : AbstractContext(kKind) {} + ~ImmediateExecutionContext() override {} }; } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h new file mode 100644 index 00000000000..31413b5b4b9 --- /dev/null +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +struct TFE_Op; + +namespace tensorflow { + +// Abstract interface to an operation. +class ImmediateExecutionOperation : public AbstractOperation { + public: + static constexpr AbstractOperationKind kKind = kImmediateExecution; + virtual void Clear() = 0; + + virtual const tensorflow::OpDef* OpDef() const = 0; + + virtual Status InputLength(const char* input_name, int* length) = 0; + virtual Status OutputLength(const char* output_name, int* length) = 0; + + // Experimental + virtual Status SetUseXla(bool enable) = 0; + + protected: + ImmediateExecutionOperation() : AbstractOperation(kKind) {} + ~ImmediateExecutionOperation() override {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h similarity index 74% rename from tensorflow/c/eager/tensor_handle_interface.h rename to tensorflow/c/eager/immediate_execution_tensor_handle.h index 1ca40daec41..1f5a77e54ee 100644 --- a/tensorflow/c/eager/tensor_handle_interface.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" @@ -30,15 +31,9 @@ namespace tensorflow { // files. The interface lists the common functionality that must be provided by // any concrete implementation. However, in cases where the true concrete class // is needed a static_cast can be applied. -class AbstractTensorHandleInterface { +class ImmediateExecutionTensorHandle : public AbstractTensorHandle { public: - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus this must be allocated on the heap and - // clients MUST call Release() in order to destroy an instance of this class. - virtual void Release() = 0; + static constexpr AbstractTensorHandleKind kKind = kImmediateExecution; // Returns tensor dtype. virtual tensorflow::DataType DataType() const = 0; @@ -57,12 +52,13 @@ class AbstractTensorHandleInterface { virtual AbstractTensorInterface* Resolve(Status* status) = 0; // Return a copy of the handle. - virtual AbstractTensorHandleInterface* Copy() = 0; + virtual ImmediateExecutionTensorHandle* Copy() = 0; protected: - virtual ~AbstractTensorHandleInterface() {} + ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {} + ~ImmediateExecutionTensorHandle() override {} }; } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ diff --git a/tensorflow/c/eager/tfe_context_internal.h b/tensorflow/c/eager/tfe_context_internal.h index 1d29bee9ee3..1f2035317fa 100644 --- a/tensorflow/c/eager/tfe_context_internal.h +++ b/tensorflow/c/eager/tfe_context_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/context_interface.h" +#include "tensorflow/c/eager/immediate_execution_context.h" // Wraps a pointer to a context implementation. // @@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context); } // namespace tensorflow diff --git a/tensorflow/c/eager/tfe_op_internal.h b/tensorflow/c/eager/tfe_op_internal.h index 6ca7f741d16..3fe94d358b6 100644 --- a/tensorflow/c/eager/tfe_op_internal.h +++ b/tensorflow/c/eager/tfe_op_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" // Wraps a pointer to an operation implementation. // @@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op); -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*); } // namespace tensorflow diff --git a/tensorflow/c/eager/tfe_tensorhandle_internal.h b/tensorflow/c/eager/tfe_tensorhandle_internal.h index 543e5f1d932..308e8c24e2c 100644 --- a/tensorflow/c/eager/tfe_tensorhandle_internal.h +++ b/tensorflow/c/eager/tfe_tensorhandle_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" // Wraps a pointer to a tensor handle implementation. // @@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface, +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle, TFE_TensorHandle); -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*, +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*, TFE_TensorHandle*); } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 2e817ed02e0..dbe1b6d656c 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -23,8 +23,8 @@ cc_library( ], deps = [ ":function_metadata", - "//tensorflow/c/eager:operation_interface", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc index d5da2ca9bf4..41bae4352fc 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" namespace tensorflow { -const std::vector& +const std::vector& ConcreteFunction::GetCaptures() const { return captures_; } diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 6f8a5375277..22535641ef5 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/core/framework/function.pb.h" @@ -38,15 +38,15 @@ class ConcreteFunction { virtual ~ConcreteFunction() = 0; // This method returns the "Call" Op used to execute the function. - virtual AbstractOperationInterface* GetCallOp() = 0; + virtual ImmediateExecutionOperation* GetCallOp() = 0; - const std::vector& GetCaptures() + const std::vector& GetCaptures() const; const FunctionMetadata& GetFunctionMetadata() const; private: FunctionMetadata metadata_; - std::vector captures_; + std::vector captures_; FunctionDef* function_; }; diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index aa909c692ca..8c4c41c6d75 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -20,7 +20,7 @@ cc_library( "owned_eager_op.h", ], deps = [ - "//tensorflow/c/eager:operation_interface", + "//tensorflow/c/eager:immediate_execution_operation", ], ) @@ -30,7 +30,7 @@ cc_library( "owned_tensor_handle.h", ], deps = [ - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) @@ -39,7 +39,7 @@ cc_library( name = "owned_eager_context", hdrs = ["owned_eager_context.h"], deps = [ - "//tensorflow/c/eager:context_interface", + "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/core/common_runtime/eager:context", ], ) @@ -63,8 +63,9 @@ cc_library( deps = [ ":owned_eager_op", ":owned_tensor_handle", - "//tensorflow/c/eager:context_interface", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h b/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h index 300059cd069..d944fcb51a2 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h +++ b/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h @@ -18,14 +18,14 @@ limitations under the License. #include -#include "tensorflow/c/eager/context_interface.h" +#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/core/common_runtime/eager/context.h" namespace tensorflow { namespace internal { -struct AbstractContextInterfaceDeleter { - void operator()(AbstractContextInterface* p) const { +struct ImmediateExecutionContextDeleter { + void operator()(ImmediateExecutionContext* p) const { if (p != nullptr) { p->Release(); } @@ -43,8 +43,8 @@ struct EagerContextDeleter { } // namespace internal using AbstractContextPtr = - std::unique_ptr; + std::unique_ptr; using EagerContextPtr = std::unique_ptr; diff --git a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h b/tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h index c6b21578820..b3a08334a97 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h +++ b/tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h @@ -18,13 +18,13 @@ limitations under the License. #include -#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" namespace tensorflow { namespace internal { -struct AbstractOperationInterfaceDeleter { - void operator()(AbstractOperationInterface* p) const { +struct ImmediateExecutionOperationDeleter { + void operator()(ImmediateExecutionOperation* p) const { if (p != nullptr) { p->Release(); } @@ -34,8 +34,8 @@ struct AbstractOperationInterfaceDeleter { } // namespace internal using AbstractOpPtr = - std::unique_ptr; + std::unique_ptr; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h b/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h index e98d6554afb..c52ebaa2479 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h +++ b/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" namespace tensorflow { @@ -33,7 +33,7 @@ struct TensorHandleDeleter { }; struct AbstractTensorHandleDeleter { - void operator()(AbstractTensorHandleInterface* p) const { + void operator()(ImmediateExecutionTensorHandle* p) const { if (p != nullptr) { p->Release(); } @@ -46,7 +46,7 @@ using TensorHandlePtr = std::unique_ptr; using AbstractTensorHandlePtr = - std::unique_ptr; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index a3b3ace7be9..eb06662722e 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -16,7 +16,8 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "absl/types/span.h" -#include "tensorflow/c/eager/context_interface.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -32,7 +33,7 @@ namespace internal { static const char kNoSharingResourceID[] = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; -Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, +Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, AbstractTensorHandlePtr* handle) { AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation()); @@ -50,17 +51,20 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, TF_RETURN_IF_ERROR(varhandle_op->SetAttrString( "shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID))); - AbstractTensorHandleInterface* var_handle = nullptr; + AbstractTensorHandle* var_handle = nullptr; int num_retvals = 1; TF_RETURN_IF_ERROR(varhandle_op->Execute( absl::MakeSpan(&var_handle, num_retvals), &num_retvals)); - handle->reset(var_handle); + if (var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) { + return errors::Internal("Unexpected tensor handle kind."); + } + handle->reset(reinterpret_cast(var_handle)); return Status(); } -Status AssignVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandleInterface* value) { +Status AssignVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateExecutionTensorHandle* value) { AbstractOpPtr assign_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr)); TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype)); @@ -72,24 +76,27 @@ Status AssignVariable(AbstractContextInterface* ctx, return Status(); } -Status ReadVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, +Status ReadVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, DataType dtype, AbstractTensorHandlePtr* output) { AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr)); TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype)); TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle)); - AbstractTensorHandleInterface* value = nullptr; + AbstractTensorHandle* value = nullptr; int num_retvals = 1; TF_RETURN_IF_ERROR( read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals)); - output->reset(value); + if (value->getKind() != ImmediateExecutionTensorHandle::kKind) { + return errors::Internal("Unexpected tensor handle kind."); + } + output->reset(reinterpret_cast(value)); return Status(); } -Status DestroyResource(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* handle) { +Status DestroyResource(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* handle) { AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr)); TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true)); diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index 8a410328b9e..038b2c3d62a 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H -#include "tensorflow/c/eager/context_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -30,7 +30,7 @@ namespace internal { // TensorHandle associated with the variable. This is equivalent to creating an // unitialized TF2 tf.Variable. // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 -Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, +Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, AbstractTensorHandlePtr* handle); @@ -39,22 +39,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, // underlying variable for `variable_handle`. Note that it is illegal to assign // a variable to a Tensor with a different dtype than what the variable was // created with. -Status AssignVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandleInterface* value); +Status AssignVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateExecutionTensorHandle* value); // Executes a ReadVariableOp using `ctx`. This reads the underlying variable // value of `variable_handle` and copies the value to `output`. `dtype` must be // the dtype of the variable associated with `variable_handle`. -Status ReadVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, +Status ReadVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, DataType dtype, AbstractTensorHandlePtr* output); // Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to // the cleanup that occurs in a tf.Variable's EagerResourceDeleter: // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290 -Status DestroyResource(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* handle); +Status DestroyResource(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* handle); } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 72474940c16..888c284bb12 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -178,7 +178,7 @@ cc_library( ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:tfe_tensorhandle_internal", ], ) @@ -190,7 +190,7 @@ cc_library( ], deps = [ "//tensorflow/c:conversion_macros", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc index 7d018658101..c8f00c1f7c0 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h index 8cbec2806a8..566417df025 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" // Internal structures used by the SavedModel C API. These are likely to // change and should not be depended on. @@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList; namespace tensorflow { DEFINE_CONVERSION_FUNCTIONS( - std::vector, + std::vector, TF_TensorHandleList) } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index e4f4c483209..fb69bcb7ab5 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -32,6 +32,8 @@ tf_cuda_library( ":tensor_handle", "//tensorflow/c:c_api_internal", "//tensorflow/c:tf_tensor_internal", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/core/platform:errors", ], alwayslink = 1, ) @@ -74,9 +76,9 @@ tf_cuda_library( ":kernel_and_device", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_tensor_internal", - "//tensorflow/c/eager:context_interface", - "//tensorflow/c/eager:tensor_handle_interface", - "//tensorflow/c/eager:operation_interface", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker_env", ] + select({ @@ -137,8 +139,10 @@ tf_cuda_library( "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "//tensorflow/c:tf_tensor_internal", - "//tensorflow/c/eager:operation_interface", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", @@ -211,7 +215,7 @@ tf_cuda_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", "//tensorflow/c:tf_tensor_internal", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -496,6 +500,8 @@ cc_library( "//tensorflow/c:tf_tensor_internal", "//tensorflow/compiler/jit:common", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/core/platform:errors", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 970c2bcbb89..6dc0a3a8200 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -30,8 +30,6 @@ limitations under the License. #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor_internal.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" #include "tensorflow/core/common_runtime/colocation_graph.h" diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index cb6d09f8f1d..141327c08cb 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,7 +33,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/container/flat_hash_map.h" -#include "tensorflow/c/eager/context_interface.h" +#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -135,7 +135,7 @@ class CustomDevice { // TensorHandles may be placed either on custom or physical devices. using VariantDevice = absl::variant; -class EagerContext : public AbstractContextInterface, public core::RefCounted { +class EagerContext : public ImmediateExecutionContext, public core::RefCounted { public: static constexpr uint64 kInvalidContextId = 0; @@ -178,12 +178,14 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { MemoryReleaser memory_releaser, void* memory_releaser_arg) override; - AbstractTensorHandleInterface* CreateLocalHandle( + ImmediateExecutionTensorHandle* CreateLocalHandle( AbstractTensorInterface* t) override; - AbstractTensorHandleInterface* CopyTensorHandleToDevice( - AbstractTensorHandleInterface* handle, const char* device_name, + ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, Status* status) override; - AbstractOperationInterface* CreateOperation() override; + ImmediateExecutionOperation* CreateOperation() override; + + Status RegisterFunction(AbstractFunction* f) override; bool UsesTFRT() override; @@ -716,7 +718,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { std::function resource_deallocator_ = nullptr; }; -inline EagerContext* ContextFromInterface(AbstractContextInterface* context) { +inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) { return down_cast(context); } diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index e342f6ae6cd..3d37250a4fe 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/abstract_function.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/platform/errors.h" namespace { @@ -112,8 +114,8 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { } } -AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice( - AbstractTensorHandleInterface* handle, const char* device_name, +ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, Status* status) { TensorHandle* input = TensorHandleFromInterface(handle); TensorHandle* result = nullptr; @@ -158,7 +160,7 @@ AbstractTensorHandleInterface* EagerContext::CopyTensorHandleToDevice( // here to a circular BUILD dep issue. If we move this to context.cc, then we // will have the circular dependency of: // context -> tensor_handle -> remote_tensor_handle_data -> context -AbstractTensorHandleInterface* EagerContext::CreateLocalHandle( +ImmediateExecutionTensorHandle* EagerContext::CreateLocalHandle( AbstractTensorInterface* t) { Tensor tensor = TensorFromInterface(t); return TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/HostCPU(), @@ -168,14 +170,23 @@ AbstractTensorHandleInterface* EagerContext::CreateLocalHandle( // TODO(b/152902651): We have to keep this function here since EagerOperation // depends on EagerContext. Thus, the context build target can't depend on // EagerOperation. -AbstractOperationInterface* EagerContext::CreateOperation() { +ImmediateExecutionOperation* EagerContext::CreateOperation() { return new EagerOperation(this); } +Status EagerContext::RegisterFunction(AbstractFunction* f) { + FunctionDef* fdef; + TF_RETURN_IF_ERROR(f->GetFunctionDef(&fdef)); + if (!fdef) { + return errors::InvalidArgument("GetFunctionDef returned nullptr."); + } + return AddFunctionDef(*fdef); +} + // TODO(b/152902651): Once we move many execute.cc functions into // eager_operation.cc we can avoid a circular dependency between them. -Status EagerOperation::Execute( - absl::Span retvals, int* num_retvals) { +Status EagerOperation::Execute(absl::Span retvals, + int* num_retvals) { return EagerExecute( this, reinterpret_cast(retvals.data()), num_retvals); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 090bfef46bd..073095e64d1 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "absl/types/span.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" @@ -91,8 +93,8 @@ Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims, return Status::OK(); } -Status EagerOperation::SetAttrFunction( - const char* attr_name, const AbstractOperationInterface* value) { +Status EagerOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { AttrValue attr_value; NameAttrList* func = attr_value.mutable_func(); func->set_name(value->Name()); @@ -194,8 +196,7 @@ Status EagerOperation::SetAttrShapeList(const char* attr_name, } Status EagerOperation::SetAttrFunctionList( - const char* attr_name, - absl::Span values) { + const char* attr_name, absl::Span values) { size_t num_values = values.size(); std::unique_ptr funcs(new NameAttrList[num_values]); for (int i = 0; i < num_values; i++) { @@ -253,14 +254,13 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) { return Status::OK(); } -Status EagerOperation::AddInput(AbstractTensorHandleInterface* input) { +Status EagerOperation::AddInput(AbstractTensorHandle* input) { TensorHandle* h = TensorHandleFromInterface(input); AddTensorHandle(h); return MaybeInferSingleInputAttrs(h); } -Status EagerOperation::AddInputList( - absl::Span inputs) { +Status EagerOperation::AddInputList(absl::Span inputs) { for (auto& input : inputs) { TensorHandle* h = TensorHandleFromInterface(input); AddTensorHandle(h); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 14268ef2630..963aed25733 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" @@ -31,7 +32,7 @@ limitations under the License. namespace tensorflow { -class EagerOperation : public AbstractOperationInterface { +class EagerOperation : public ImmediateExecutionOperation { public: explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {} ~EagerOperation() override { @@ -56,7 +57,7 @@ class EagerOperation : public AbstractOperationInterface { } // Replaces the previous device name with the given one (see - // AbstractOperationInterface::SetDeviceName for more details). + // AbstractOperation::SetDeviceName for more details). // // This also resets the internal device pointer, unless the given name refers // to a known custom device, in which case the internal device pointer is @@ -76,10 +77,9 @@ class EagerOperation : public AbstractOperationInterface { Status SetAttrValue(const char* attr_name, const AttrValue& value); - Status AddInput(AbstractTensorHandleInterface* input) override; - Status AddInputList( - absl::Span inputs) override; - Status Execute(absl::Span retvals, + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span inputs) override; + Status Execute(absl::Span retvals, int* num_retvals) override; const tensorflow::OpDef* OpDef() const override { return op_def_; }; @@ -92,7 +92,7 @@ class EagerOperation : public AbstractOperationInterface { Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) override; Status SetAttrFunction(const char* attr_name, - const AbstractOperationInterface* value) override; + const AbstractOperation* value) override; Status SetAttrFunctionName(const char* attr_name, const char* data, size_t length) override; Status SetAttrTensor(const char* attr_name, @@ -111,7 +111,7 @@ class EagerOperation : public AbstractOperationInterface { const int* num_dims, int num_values) override; Status SetAttrFunctionList( const char* attr_name, - absl::Span values) override; + absl::Span values) override; Status InputLength(const char* input_name, int* length) override; Status OutputLength(const char* output_name, int* length) override; @@ -235,7 +235,7 @@ inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { } inline EagerOperation* OperationFromInterface( - AbstractOperationInterface* operation) { + ImmediateExecutionOperation* operation) { return down_cast(operation); } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 9b82c556cd0..9e607c97683 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -1071,7 +1071,7 @@ const char* TensorHandle::BackingDeviceName(Status* status) const { } } -tensorflow::AbstractTensorHandleInterface* TensorHandle::Copy() { +tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() { Ref(); return this; } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 5e7638ae03c..a14df475e0f 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -31,7 +31,7 @@ limitations under the License. // clang-format on #include "absl/types/variant.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" @@ -53,7 +53,7 @@ class EagerContext; // Associates a Tensor and a Device, used in the eager runtime. Internal version // of the TFE_TensorHandle struct and the python EagerTensor class // (unrelated to python TensorHandle). -class TensorHandle : public AbstractTensorHandleInterface, +class TensorHandle : public ImmediateExecutionTensorHandle, public core::RefCounted { // TensorHandle for dtype != DT_RESOURCE TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, @@ -121,7 +121,7 @@ class TensorHandle : public AbstractTensorHandleInterface, const char* BackingDeviceName(Status* status) const override; AbstractTensorInterface* Resolve(Status* status) override; - AbstractTensorHandleInterface* Copy() override; + ImmediateExecutionTensorHandle* Copy() override; // Return the Tensor from the default device. Status Tensor(const tensorflow::Tensor** t) const; @@ -372,12 +372,12 @@ const VariantDevice kVariantDeviceNull = static_cast(nullptr); // Returns the device backing the resource. Else, returns nullptr. Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); -class TensorHandleInterface : public AbstractTensorHandleInterface { +class TensorHandleInterface : public ImmediateExecutionTensorHandle { public: }; -inline TensorHandle* TensorHandleFromInterface( - AbstractTensorHandleInterface* handle) { +template +inline TensorHandle* TensorHandleFromInterface(T* handle) { return down_cast(handle); } diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 639f623bd1a..b9ff474caab 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2008,7 +2008,7 @@ bool ListContainsNone(PyObject* list) { static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { - tensorflow::AbstractTensorHandleInterface* handle = + tensorflow::ImmediateExecutionTensorHandle* handle = tensorflow::unwrap(EagerTensor_Handle(tensor)); tensorflow::int64 id = PyEagerTensor_ID(tensor); tensorflow::DataType dtype = @@ -3869,7 +3869,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, bool include_tensor_ranks_only, EncodeResult* result) { if (EagerTensor_CheckExact(arg)) { - tensorflow::AbstractTensorHandleInterface* handle = + tensorflow::ImmediateExecutionTensorHandle* handle = tensorflow::unwrap(EagerTensor_Handle(arg)); absl::StrAppend(&result->str, kDType,