Attempting to split c_api_internal.h into separate files.
PiperOrigin-RevId: 307638566 Change-Id: If1cf27baadb3d3fe32808b55f19b887f490ac119
This commit is contained in:
parent
9c925a52e8
commit
2d3b485d0d
|
@ -41,12 +41,21 @@ tf_cuda_library(
|
||||||
":context_interface",
|
":context_interface",
|
||||||
":operation_interface",
|
":operation_interface",
|
||||||
":tensor_handle_interface",
|
":tensor_handle_interface",
|
||||||
|
":tfe_context_internal",
|
||||||
|
":tfe_cancellation_manager_internal",
|
||||||
|
":tfe_executor_internal",
|
||||||
|
":tfe_monitoring_internal",
|
||||||
|
":tfe_op_attrs_internal",
|
||||||
|
":tfe_op_internal",
|
||||||
|
":tfe_tensor_debug_info_internal",
|
||||||
|
":tfe_tensorhandle_internal",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/container:fixed_array",
|
"@com_google_absl//absl/container:fixed_array",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@com_google_absl//absl/types:variant",
|
"@com_google_absl//absl/types:variant",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_api_internal",
|
"//tensorflow/c:c_api_internal",
|
||||||
|
"//tensorflow/c:tf_status_internal",
|
||||||
"//tensorflow/c:tf_tensor_internal",
|
"//tensorflow/c:tf_tensor_internal",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||||
|
@ -100,6 +109,14 @@ filegroup(
|
||||||
"dlpack.h",
|
"dlpack.h",
|
||||||
"operation_interface.h",
|
"operation_interface.h",
|
||||||
"tensor_handle_interface.h",
|
"tensor_handle_interface.h",
|
||||||
|
"tfe_cancellation_manager_internal.h",
|
||||||
|
"tfe_context_internal.h",
|
||||||
|
"tfe_executor_internal.h",
|
||||||
|
"tfe_monitoring_internal.h",
|
||||||
|
"tfe_op_attrs_internal.h",
|
||||||
|
"tfe_op_internal.h",
|
||||||
|
"tfe_tensor_debug_info_internal.h",
|
||||||
|
"tfe_tensorhandle_internal.h",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow/core:__pkg__",
|
"//tensorflow/core:__pkg__",
|
||||||
|
@ -107,33 +124,27 @@ filegroup(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
cc_library(
|
||||||
name = "c_api_internal",
|
name = "c_api_internal",
|
||||||
srcs = [
|
hdrs = [
|
||||||
"c_api_experimental.h",
|
"c_api_experimental.h",
|
||||||
"c_api_unified_experimental.h",
|
"c_api_internal.h",
|
||||||
],
|
],
|
||||||
hdrs = ["c_api_internal.h"],
|
|
||||||
visibility = [
|
visibility = [
|
||||||
"//learning/deepmind/courier:__subpackages__",
|
"//learning/deepmind/courier:__subpackages__",
|
||||||
"//tensorflow:internal",
|
"//tensorflow:internal",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":c_api",
|
":c_api",
|
||||||
":context_interface",
|
":tfe_cancellation_manager_internal",
|
||||||
":operation_interface",
|
":tfe_context_internal",
|
||||||
":tensor_handle_interface",
|
":tfe_executor_internal",
|
||||||
"//tensorflow/c:c_api",
|
":tfe_monitoring_internal",
|
||||||
|
":tfe_op_attrs_internal",
|
||||||
|
":tfe_op_internal",
|
||||||
|
":tfe_tensor_debug_info_internal",
|
||||||
|
":tfe_tensorhandle_internal",
|
||||||
"//tensorflow/c:c_api_internal",
|
"//tensorflow/c:c_api_internal",
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:core_cpu_lib",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:framework_internal",
|
|
||||||
"//tensorflow/core:framework_lite",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:lib_internal",
|
|
||||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
|
||||||
"//tensorflow/core/common_runtime/eager:eager_executor",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -184,6 +195,99 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_context_internal",
|
||||||
|
hdrs = ["tfe_context_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":context_interface",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_cancellation_manager_internal",
|
||||||
|
hdrs = ["tfe_cancellation_manager_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_executor_internal",
|
||||||
|
hdrs = ["tfe_executor_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core/common_runtime/eager:eager_executor",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_monitoring_internal",
|
||||||
|
hdrs = ["tfe_monitoring_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_op_attrs_internal",
|
||||||
|
hdrs = ["tfe_op_attrs_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tfe_context_internal",
|
||||||
|
":tfe_op_internal",
|
||||||
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_op_internal",
|
||||||
|
hdrs = ["tfe_op_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":operation_interface",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_tensor_debug_info_internal",
|
||||||
|
hdrs = ["tfe_tensor_debug_info_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tfe_tensorhandle_internal",
|
||||||
|
hdrs = ["tfe_tensorhandle_internal.h"],
|
||||||
|
visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":tensor_handle_interface",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "c_api_test_util",
|
name = "c_api_test_util",
|
||||||
testonly = 1,
|
testonly = 1,
|
||||||
|
|
|
@ -17,8 +17,11 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h"
|
||||||
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
|
#include "tensorflow/c/tf_status_internal.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||||
#include "tensorflow/compiler/jit/xla_device.h"
|
#include "tensorflow/compiler/jit/xla_device.h"
|
||||||
#endif // TENSORFLOW_EAGER_USE_XLA
|
#endif // TENSORFLOW_EAGER_USE_XLA
|
||||||
|
|
|
@ -15,39 +15,20 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||||
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstddef>
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
|
||||||
#include <queue>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/context_interface.h"
|
#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/c/eager/operation_interface.h"
|
#include "tensorflow/c/eager/tfe_context_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
#include "tensorflow/c/eager/tfe_op_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" // IWYU pragma: export
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
|
||||||
#include "tensorflow/core/framework/rendezvous.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
|
||||||
#include "tensorflow/core/platform/stringpiece.h"
|
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
|
||||||
#include "tensorflow/core/public/version.h"
|
|
||||||
|
|
||||||
|
// TODO(b/154564140): Move this to its own header. This requires splitting
|
||||||
|
// c_api_experimental.h
|
||||||
struct TFE_ContextOptions {
|
struct TFE_ContextOptions {
|
||||||
TF_SessionOptions session_options;
|
TF_SessionOptions session_options;
|
||||||
// true if async execution is enabled.
|
// true if async execution is enabled.
|
||||||
|
@ -61,199 +42,4 @@ struct TFE_ContextOptions {
|
||||||
bool use_tfrt = false;
|
bool use_tfrt = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wraps a pointer to a context implementation.
|
|
||||||
//
|
|
||||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
|
||||||
// interface cannot destruct the underlying context object. Instead, call
|
|
||||||
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
|
||||||
// the TFE_Context structure.
|
|
||||||
struct TFE_Context {
|
|
||||||
tensorflow::AbstractContextInterface* context;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wraps a pointer to a tensor handle implementation.
|
|
||||||
//
|
|
||||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
|
||||||
// interface cannot destruct the underlying handle object. Instead, call
|
|
||||||
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
|
||||||
// the TFE_TensorHandle structure.
|
|
||||||
struct TFE_TensorHandle {
|
|
||||||
tensorflow::AbstractTensorHandleInterface* handle;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_TensorDebugInfo {
|
|
||||||
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
|
||||||
: dev_dims(dims) {}
|
|
||||||
|
|
||||||
// Fully-padded, minor-to-major.
|
|
||||||
std::vector<tensorflow::int64> dev_dims;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wraps a pointer to an operation implementation.
|
|
||||||
//
|
|
||||||
// WARNING: Since the underlying object could be ref-counted a user of this
|
|
||||||
// interface cannot destruct the underlying operation object. Instead, call
|
|
||||||
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
|
||||||
// the TFE_Op structure.
|
|
||||||
struct TFE_Op {
|
|
||||||
tensorflow::AbstractOperationInterface* operation;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringCounterCell {
|
|
||||||
tensorflow::monitoring::CounterCell cell;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <int NumLabels>
|
|
||||||
struct TFE_MonitoringCounter {
|
|
||||||
template <typename... LabelDesc>
|
|
||||||
TFE_MonitoringCounter(const char* name, const char* description,
|
|
||||||
LabelDesc&&... label) {
|
|
||||||
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
|
||||||
name, description, label...));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
|
||||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
|
||||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
|
||||||
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringIntGaugeCell {
|
|
||||||
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringStringGaugeCell {
|
|
||||||
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringBoolGaugeCell {
|
|
||||||
tensorflow::monitoring::GaugeCell<bool> cell;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename ValueType, int NumLabels>
|
|
||||||
struct TFE_MonitoringGauge {
|
|
||||||
template <typename... LabelDesc>
|
|
||||||
TFE_MonitoringGauge(const char* name, const char* description,
|
|
||||||
LabelDesc&&... label) {
|
|
||||||
gauge = absl::WrapUnique(
|
|
||||||
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
|
||||||
name, description, label...));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
|
||||||
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringBuckets {
|
|
||||||
explicit TFE_MonitoringBuckets(
|
|
||||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
|
||||||
fn) {
|
|
||||||
create_buckets = fn;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
|
||||||
create_buckets;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringSamplerCell {
|
|
||||||
tensorflow::monitoring::SamplerCell cell;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <int NumLabels>
|
|
||||||
struct TFE_MonitoringSampler {
|
|
||||||
template <typename... LabelDesc>
|
|
||||||
TFE_MonitoringSampler(
|
|
||||||
const char* name,
|
|
||||||
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
|
||||||
const char* description, LabelDesc&&... label) {
|
|
||||||
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
|
||||||
{name, description, label...}, std::move(buckets)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
|
||||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
|
||||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
|
||||||
};
|
|
||||||
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
|
||||||
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
// Set an AttrValue on the op. Doesn't handle the list types.
|
|
||||||
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
|
||||||
const tensorflow::AttrValue& default_value,
|
|
||||||
const char* attr_name, TF_Status* status);
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
struct TFE_CancellationManager {
|
|
||||||
tensorflow::CancellationManager cancellation_manager;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TFE_Executor {
|
|
||||||
explicit TFE_Executor(bool async)
|
|
||||||
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
|
||||||
|
|
||||||
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
|
||||||
: owned_executor(nullptr), unowned_executor(executor) {}
|
|
||||||
|
|
||||||
tensorflow::EagerExecutor* executor() {
|
|
||||||
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
|
||||||
tensorflow::EagerExecutor* unowned_executor;
|
|
||||||
};
|
|
||||||
|
|
||||||
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
|
||||||
// that sometimes do not require serialization.
|
|
||||||
struct TFE_OpAttrs {
|
|
||||||
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
|
||||||
|
|
||||||
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
|
||||||
const char* op_name)
|
|
||||||
: name(op_name), attributes(value) {}
|
|
||||||
|
|
||||||
const char* name;
|
|
||||||
const tensorflow::AttrBuilder* attributes;
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
|
|
||||||
|
struct TFE_CancellationManager {
|
||||||
|
tensorflow::CancellationManager cancellation_manager;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/context_interface.h"
|
||||||
|
|
||||||
|
// Wraps a pointer to a context implementation.
|
||||||
|
//
|
||||||
|
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||||
|
// interface cannot destruct the underlying context object. Instead, call
|
||||||
|
// TFE_DeleteContext who calls Release() on the context pointer and deletes
|
||||||
|
// the TFE_Context structure.
|
||||||
|
struct TFE_Context {
|
||||||
|
tensorflow::AbstractContextInterface* context;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_
|
|
@ -0,0 +1,37 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||||
|
|
||||||
|
struct TFE_Executor {
|
||||||
|
explicit TFE_Executor(bool async)
|
||||||
|
: owned_executor(new tensorflow::EagerExecutor(async)) {}
|
||||||
|
|
||||||
|
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
|
||||||
|
: owned_executor(nullptr), unowned_executor(executor) {}
|
||||||
|
|
||||||
|
tensorflow::EagerExecutor* executor() {
|
||||||
|
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
|
||||||
|
tensorflow::EagerExecutor* unowned_executor;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_
|
|
@ -0,0 +1,146 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/memory/memory.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
struct TFE_MonitoringCounterCell {
|
||||||
|
tensorflow::monitoring::CounterCell cell;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int NumLabels>
|
||||||
|
struct TFE_MonitoringCounter {
|
||||||
|
template <typename... LabelDesc>
|
||||||
|
TFE_MonitoringCounter(const char* name, const char* description,
|
||||||
|
LabelDesc&&... label) {
|
||||||
|
counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New(
|
||||||
|
name, description, label...));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> {
|
||||||
|
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> {
|
||||||
|
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> {
|
||||||
|
using TFE_MonitoringCounter::TFE_MonitoringCounter;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringIntGaugeCell {
|
||||||
|
tensorflow::monitoring::GaugeCell<tensorflow::int64> cell;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringStringGaugeCell {
|
||||||
|
tensorflow::monitoring::GaugeCell<tensorflow::string> cell;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringBoolGaugeCell {
|
||||||
|
tensorflow::monitoring::GaugeCell<bool> cell;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename ValueType, int NumLabels>
|
||||||
|
struct TFE_MonitoringGauge {
|
||||||
|
template <typename... LabelDesc>
|
||||||
|
TFE_MonitoringGauge(const char* name, const char* description,
|
||||||
|
LabelDesc&&... label) {
|
||||||
|
gauge = absl::WrapUnique(
|
||||||
|
tensorflow::monitoring::Gauge<ValueType, NumLabels>::New(
|
||||||
|
name, description, label...));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> {
|
||||||
|
using TFE_MonitoringGauge::TFE_MonitoringGauge;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringBuckets {
|
||||||
|
explicit TFE_MonitoringBuckets(
|
||||||
|
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||||
|
fn) {
|
||||||
|
create_buckets = fn;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)>
|
||||||
|
create_buckets;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringSamplerCell {
|
||||||
|
tensorflow::monitoring::SamplerCell cell;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int NumLabels>
|
||||||
|
struct TFE_MonitoringSampler {
|
||||||
|
template <typename... LabelDesc>
|
||||||
|
TFE_MonitoringSampler(
|
||||||
|
const char* name,
|
||||||
|
std::unique_ptr<tensorflow::monitoring::Buckets> buckets,
|
||||||
|
const char* description, LabelDesc&&... label) {
|
||||||
|
sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New(
|
||||||
|
{name, description, label...}, std::move(buckets)));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> {
|
||||||
|
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> {
|
||||||
|
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||||
|
};
|
||||||
|
struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> {
|
||||||
|
using TFE_MonitoringSampler::TFE_MonitoringSampler;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_
|
|
@ -0,0 +1,52 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstddef>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||||
|
#include "tensorflow/c/eager/tfe_op_internal.h"
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
|
||||||
|
// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways
|
||||||
|
// that sometimes do not require serialization.
|
||||||
|
struct TFE_OpAttrs {
|
||||||
|
explicit TFE_OpAttrs() : name(nullptr), attributes(nullptr) {}
|
||||||
|
|
||||||
|
explicit TFE_OpAttrs(const tensorflow::AttrBuilder* value,
|
||||||
|
const char* op_name)
|
||||||
|
: name(op_name), attributes(value) {}
|
||||||
|
|
||||||
|
const char* name;
|
||||||
|
const tensorflow::AttrBuilder* attributes;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
// Set an AttrValue on the op. Doesn't handle the list types.
|
||||||
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||||
|
const tensorflow::AttrValue& default_value,
|
||||||
|
const char* attr_name, TF_Status* status);
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/operation_interface.h"
|
||||||
|
|
||||||
|
// Wraps a pointer to an operation implementation.
|
||||||
|
//
|
||||||
|
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||||
|
// interface cannot destruct the underlying operation object. Instead, call
|
||||||
|
// TFE_DeleteOp who calls Release() on the operation pointer and deletes
|
||||||
|
// the TFE_Op structure.
|
||||||
|
struct TFE_Op {
|
||||||
|
tensorflow::AbstractOperationInterface* operation;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
struct TFE_TensorDebugInfo {
|
||||||
|
explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims)
|
||||||
|
: dev_dims(dims) {}
|
||||||
|
|
||||||
|
// Fully-padded, minor-to-major.
|
||||||
|
std::vector<tensorflow::int64> dev_dims;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||||
|
#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||||
|
|
||||||
|
// Wraps a pointer to a tensor handle implementation.
|
||||||
|
//
|
||||||
|
// WARNING: Since the underlying object could be ref-counted a user of this
|
||||||
|
// interface cannot destruct the underlying handle object. Instead, call
|
||||||
|
// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes
|
||||||
|
// the TFE_TensorHandle structure.
|
||||||
|
struct TFE_TensorHandle {
|
||||||
|
tensorflow::AbstractTensorHandleInterface* handle;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_
|
Loading…
Reference in New Issue