STT-tensorflow/tensorflow/c/eager/immediate_execution_context.h
Brian Zhao ebf57bdfc7 Moving RAII helpers for TensorHandle, Tensor, and Operation to their respective classes.
PiperOrigin-RevId: 317578771
Change-Id: Iaf674696ea7d7dfdf94924f4c60d555a613c5f57
2020-06-21 19:12:38 -07:00

128 lines
5.2 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
namespace tensorflow {
// Abstract interface to a context.
//
// A context is responsible for creating key objects such as Tensors,
// TensorHandles & Operations.
class ImmediateExecutionContext : public AbstractContext {
 public:
  // Kind tag this interface passes to the AbstractContext base constructor.
  static constexpr AbstractContextKind kKind = kImmediateExecution;

  // Optimized scalar creation functions: each creates a scalar tensor from
  // `value`.
  // NOTE(review): ownership of the returned AbstractTensorInterface* is not
  // spelled out in this header — confirm the release protocol with the
  // implementations.
  virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0;
  virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0;
  virtual AbstractTensorInterface* CreateInt32Scalar(int32 value) = 0;
  virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0;
  virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0;
  virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0;
  virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0;
  virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0;
  virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0;

  // Tensor creation functions.

  // Creates a tensor of element type `dtype` whose dimension sizes are listed
  // in `dim_sizes`.
  virtual AbstractTensorInterface* CreateTensor(
      DataType dtype, absl::Span<const int64> dim_sizes) = 0;

  // Signature of the callback that cleans up a caller-provided buffer handed
  // to the buffer-wrapping CreateTensor() overload below.
  // NOTE(review): presumably `data`/`len` describe that buffer and `arg` is
  // the `memory_releaser_arg` given to CreateTensor() — confirm against the
  // implementations.
  using MemoryReleaser = void (*)(void* data, size_t len, void* arg);

  // Creates a tensor instance from the given data buffer and description.
  // `memory_releaser` will be called on destruction, and it is responsible
  // for cleaning up the underlying buffer. `convert_string` indicates whether
  // the implementation has to handle tstring conversion; it is expected to be
  // removed once the tstring migration is done.
  // NOTE(review): `dims` presumably points to `num_dims` dimension sizes and
  // `len` is the byte size of `data` — confirm.
  virtual AbstractTensorInterface* CreateTensor(DataType dtype,
                                                const int64_t* dims,
                                                int num_dims, void* data,
                                                size_t len, bool convert_string,
                                                MemoryReleaser memory_releaser,
                                                void* memory_releaser_arg) = 0;

  // Creates a handle to wrap and manage the tensor `t`.
  virtual ImmediateExecutionTensorHandle* CreateLocalHandle(
      AbstractTensorInterface* t) = 0;

  // Copies `handle` to the device named `device_name`. `status` is an
  // out-parameter for the outcome of the copy (presumably set on failure —
  // confirm with implementations).
  virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice(
      ImmediateExecutionTensorHandle* handle, const char* device_name,
      Status* status) = 0;

  // Creates an operation to perform op execution. Re-declared here (still
  // pure virtual) with an ImmediateExecutionOperation* return type.
  ImmediateExecutionOperation* CreateOperation() override = 0;

  // Returns whether the runtime is backed by TFRT or the legacy TF Eager
  // runtime. This is necessary to decouple runtime-dependent code that is
  // layered on top of the runtime.
  virtual bool UsesTFRT() = 0;

  // Lists attributes of the available devices into `devices`.
  // NOTE(review): whether `devices` is cleared first or appended to is not
  // specified in this header.
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

  // NOTE(review): undocumented upstream; judging by the name it presumably
  // clears the context's caches and thread executors — confirm with the
  // implementations.
  virtual void ClearCachesAndThreadExecutors() = 0;

  // Initializes the step resource container for a training step. This is used
  // in the current TF runtime. For TFRT, it is used by the fallback op handler.
  virtual void StartStep() = 0;

  // Destroys the step resource container for a training step.
  virtual void EndStep() = 0;

  // Blocks until all pending nodes are finished.
  virtual Status AsyncWait() = 0;

  // Adds a function (serialized FunctionDef protocol buffer) so that it can
  // be executed as an op. Returns an error if a function with the same name
  // already exists.
  virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;

 protected:
  // Only subclasses can construct a context; they are tagged with kKind.
  ImmediateExecutionContext() : AbstractContext(kKind) {}

  // Non-public so instances are not deleted directly through this interface;
  // owning code goes through internal::ImmediateExecutionContextDeleter,
  // which calls Release() instead of `delete`.
  ~ImmediateExecutionContext() override = default;
};
namespace internal {
struct ImmediateExecutionContextDeleter {
void operator()(ImmediateExecutionContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using ImmediateContextPtr =
std::unique_ptr<ImmediateExecutionContext,
internal::ImmediateExecutionContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_