For now, querying the shape of the overall parallel tensor throws an error. The plumbing required to make the shape information look like not-fully-defined-shape graph tensors appears to be shallow if we want to go that route. This means that querying the shape of a parallel tensor is now a blocking operation (and needs a status return) rather than creation itself blocking.

PiperOrigin-RevId: 351907155
Change-Id: I2610613efd4bb6aafa44fc78ee53824fb6020b6a
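A sketch of what the blocking query path looks like from the C API (illustrative caller code; TFE_TensorHandleNumDims is the existing status-returning accessor):

  TF_Status* status = TF_NewStatus();
  int num_dims = TFE_TensorHandleNumDims(parallel_handle, status);  // may block
  if (TF_GetCode(status) != TF_OK) { /* the shape was not available */ }
  TF_DeleteStatus(status);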
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"

// "tensorflow/core/platform/platform.h" must be included first before using
// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc.
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h"
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h"
#endif  // !IS_MOBILE_PLATFORM

using tensorflow::string;

namespace {

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

// Annotate eager runtime construction context to the given `function_def` as
// an attribute.
void AnnotateEagerRuntimeConstructionContext(
    tensorflow::FunctionDef& function_def) {
  tensorflow::AttrValue value;
  SetAttrValue("kEagerRuntime", &value);
  (*function_def.mutable_attr())["_construction_context"] = value;
}

}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->device_placement_policy = policy;
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE)
    tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface(
        opts->session_options.options,
        static_cast<tensorflow::ContextDevicePlacementPolicy>(
            opts->device_placement_policy),
        opts->async);
#if !defined(IS_MOBILE_PLATFORM)
    tfrt_context->SetDistributedManager(
        tfrt::tf::CreateDistributedManagerContext(
            tfrt_context->GetCoreRuntime()->GetHostContext()));
#endif  // !IS_MOBILE_PLATFORM
    return tensorflow::wrap(tfrt_context);
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif  // PLATFORM_GOOGLE && !LIBTPU_ON_GCE
  }
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::StaticDeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      opts->session_options.options,
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy),
      opts->async, device_mgr.release(),
      /*device_mgr_owned*/ true, r);
#if !defined(IS_MOBILE_PLATFORM)
  eager_context->SetDistributedManager(
      std::make_unique<tensorflow::EagerContextDistributedManager>(
          eager_context));
#endif  // !IS_MOBILE_PLATFORM
  return tensorflow::wrap(eager_context);
}
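
// A minimal usage sketch (illustrative only, not compiled as part of this
// file); it uses only functions defined in this translation unit or in the
// public C API headers:
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, /*enable=*/0);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   if (TF_GetCode(status) != TF_OK) { /* report TF_Message(status) */ }
//   // ... run ops ...
//   TFE_DeleteContext(ctx);
//   TF_DeleteStatus(status);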

void TFE_DeleteContext(TFE_Context* ctx) {
  if (ctx == nullptr) {
    return;
  }

  // ctx->RefCountIsOne() should be true here.
  tensorflow::unwrap(ctx)->Release();
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* l = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&l->response);
  return l;
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->ClearCachesAndThreadExecutors();
}

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextSetServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/true, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
                                                      int keep_alive_secs,
                                                      const void* proto,
                                                      size_t proto_len,
                                                      TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextUpdateServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  } else if (context->GetContextId() ==
             tensorflow::EagerContext::kInvalidContextId) {
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to update a context with invalid context id.");
    return;
  }
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef(
          server_def, /*reset_context=*/false, keep_alive_secs);
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                 const char* worker_name,
                                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextCheckAlive not supported on mobile");
  return false;
#else   // !defined(IS_MOBILE_PLATFORM)
  bool is_alive;
  status->status =
      tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive(
          worker_name, &is_alive);
  return is_alive;
#endif  // !IS_MOBILE_PLATFORM
}

TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
                                                TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::Status::OK();
#else   // !defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::unwrap(ctx)->AsyncWait();
#endif  // !IS_MOBILE_PLATFORM
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::unwrap(ctx)->SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}

// Note: this function looks up a thread local policy. So it should be called in
// the appropriate client thread. In particular, in async mode, it may not be
// safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      tensorflow::unwrap(ctx)->GetDevicePlacementPolicy());
}

TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;

  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor));
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  if (h == nullptr) return;

  tensorflow::profiler::TraceMe activity(
      "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
  if (h) {
    tensorflow::unwrap(h)->Release();
  }
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(tensorflow::unwrap(h)->DataType());
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return -1;
  }

  int num_dims = -1;
  status->status = tensorflow::unwrap(h)->NumDims(&num_dims);
  return num_dims;
}

int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return -1;
  }

  tensorflow::int64 num_elements = -1;
  status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
  return num_elements;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return -1;
  }

  tensorflow::int64 dim = -1;
  status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
  return dim;
}
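
// Illustrative sketch (not compiled here): the shape accessors above return a
// sentinel (-1) and report failure through `status`, so callers should check
// the status before trusting the value:
//
//   TF_Status* status = TF_NewStatus();
//   int num_dims = TFE_TensorHandleNumDims(handle, status);
//   if (TF_GetCode(status) == TF_OK) {
//     for (int i = 0; i < num_dims; ++i) {
//       int64_t dim = TFE_TensorHandleDim(handle, i, status);
//       // check status again before using `dim`
//     }
//   }
//   TF_DeleteStatus(status);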

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->DeviceName(&status->status);
}

const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  tensorflow::AbstractTensorInterface* t =
      tensorflow::unwrap(h)->Resolve(&status->status);
  if (t == nullptr) {
    return nullptr;
  }

  return new TF_Tensor{t};
}
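
// Illustrative sketch (not compiled here): resolving a handle to host memory
// and reading its buffer through the TF_Tensor API:
//
//   TF_Tensor* tensor = TFE_TensorHandleResolve(handle, status);
//   if (TF_GetCode(status) == TF_OK) {
//     void* data = TF_TensorData(tensor);         // host-accessible buffer
//     size_t nbytes = TF_TensorByteSize(tensor);  // total size in bytes
//     // ... read `nbytes` bytes from `data` ...
//     TF_DeleteTensor(tensor);
//   }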

void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
               unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(unwrapped_handle);

  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  tensorflow::Device* device(handle->device());
  if (device != nullptr) {
    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}

namespace tensorflow {
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* handle,
      ImmediateExecutionTensorHandle** result) override {
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* handle,
      const tensorflow::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status Execute(const ImmediateExecutionOperation* op,
                             ImmediateExecutionTensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::unwrap(outputs[i]);
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

  tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                          ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    *result = tensorflow::unwrap(device_.pack(context_,
                                              tensorflow::wrap(handles.data()),
                                              handles.size(), &status, info_));
    return status.status;
  }

 private:
  TFE_Context* context_;
  TFE_CustomDevice device_;
  void* info_;
  string name_;
};

// An adapter which wraps the shape/data produced by C custom devices and uses
// it to implement custom device methods.
class CAPICustomDeviceTensorHandle
    : public tensorflow::CustomDeviceTensorHandle {
 public:
  using NumDimsCallback = std::function<int(TF_Status* status)>;
  using DimCallback = std::function<int64_t(int dim_index, TF_Status* status)>;
  using DeallocatorCallback = std::function<void()>;

  CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
                               tensorflow::CustomDevice* device,
                               tensorflow::DataType dtype, void* data,
                               NumDimsCallback num_dims_callback,
                               DimCallback dim_callback,
                               DeallocatorCallback deallocator)
      : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
        data_(data),
        num_dims_callback_(num_dims_callback),
        dim_callback_(dim_callback),
        deallocator_(deallocator) {}

  ~CAPICustomDeviceTensorHandle() override { deallocator_(); }
  void* DevicePointer() const override { return data_; }
  Status NumDims(int* num_dims) const override {
    TF_Status s;
    *num_dims = num_dims_callback_(&s);
    return s.status;
  }
  Status Dim(int dim_index, int64* dim) const override {
    TF_Status s;
    *dim = dim_callback_(dim_index, &s);
    return s.status;
  }

 private:
  void* const data_;
  NumDimsCallback num_dims_callback_;
  DimCallback dim_callback_;
  DeallocatorCallback deallocator_;
};

}  // namespace
}  // namespace tensorflow

TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data,
    int (*num_dims_callback)(void* data, void* arg, TF_Status* status),
    int64_t (*dim_callback)(void* data, int dim_index, void* arg,
                            TF_Status* status),
    void (*deallocator)(void* data, void* arg), void* arg, TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  tensorflow::CustomDevice* device = nullptr;
  if (!context->FindCustomDeviceFromName(device_name, &device)) {
    deallocator(data, arg);
    status->status =
        tensorflow::errors::InvalidArgument(device_name, " unknown device.");
    return nullptr;
  }
  return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
      context, device, *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
      /*num_dims_callback=*/
      [num_dims_callback, data, arg](TF_Status* status) {
        return num_dims_callback(data, arg, status);
      },
      /*dim_callback=*/
      [dim_callback, data, arg](int dim_index, TF_Status* status) {
        return dim_callback(data, dim_index, arg, status);
      },
      /*deallocator=*/[deallocator, data, arg]() { deallocator(data, arg); }));
}
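
// Illustrative callback sketch (hypothetical caller code, not part of this
// file): TFE_NewCustomDeviceTensorHandle expects C callbacks that report the
// rank and each dimension, plus a deallocator for `data`. The device name and
// the callback bodies below are placeholders:
//
//   int MyNumDims(void* data, void* arg, TF_Status* status) { return 2; }
//   int64_t MyDim(void* data, int dim_index, void* arg, TF_Status* status) {
//     return dim_index == 0 ? 3 : 4;  // a 3x4 tensor
//   }
//   void MyDeallocator(void* data, void* arg) { free(data); }
//
//   TFE_TensorHandle* h = TFE_NewCustomDeviceTensorHandle(
//       ctx, "/job:localhost/replica:0/task:0/device:CUSTOM:0", TF_FLOAT,
//       data, &MyNumDims, &MyDim, &MyDeallocator, /*arg=*/nullptr, status);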

TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::Device* device = nullptr;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status = context->FindDeviceFromName(device_name, &device);
  tensorflow::CustomDevice* custom_device = nullptr;
  if (!status->status.ok()) {
    if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
      deallocator(data, len, deallocator_arg);
      status->status =
          tensorflow::errors::InvalidArgument(device_name, " unknown device.");
      return nullptr;
    } else {
      status->status = tensorflow::Status::OK();
    }
  }
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }
  if (custom_device != nullptr) {
    return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
        context, custom_device,
        *reinterpret_cast<tensorflow::DataType*>(&dtype), data,
        /*num_dims_callback=*/
        [num_dims](TF_Status* status) { return num_dims; },
        /*dim_callback=*/
        [dimvec](int dim_index, TF_Status* status) {
          return dimvec[dim_index];
        },
        /*deallocator=*/
        [data, len, deallocator, deallocator_arg]() {
          deallocator(data, len, deallocator_arg);
        }));
  }

  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buf =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
                           /*owns_memory=*/false);

  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  buf->Unref();
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
      std::move(t), device, device, context));
}
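
// Illustrative sketch (hypothetical caller code): wrapping caller-owned host
// memory as a CPU tensor handle. The deallocator runs once the runtime is done
// with the buffer, so the caller must not free `data` itself:
//
//   void BufferDeallocator(void* data, size_t len, void* arg) { free(data); }
//
//   int64_t dims[] = {2, 2};
//   float* data = static_cast<float*>(malloc(4 * sizeof(float)));
//   TFE_TensorHandle* h = TFE_NewTensorHandleFromDeviceMemory(
//       ctx, "/job:localhost/replica:0/task:0/device:CPU:0", TF_FLOAT, dims,
//       /*num_dims=*/2, data, 4 * sizeof(float), &BufferDeallocator,
//       /*deallocator_arg=*/nullptr, status);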

// This function will block until the operation that produces `h` has
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
// bytes of the memory pointed to by the device pointer returned above.
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
                                        TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return 0;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDeviceMemorySize may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return 0;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return 0;
  }
  return tensor->TotalBytes();
}

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  tensorflow::ImmediateExecutionOperation* new_op =
      tensorflow::unwrap(ctx)->CreateOperation();
  status->status = new_op->Reset(op_or_function_name, nullptr);
  if (!status->status.ok()) {
    new_op->Release();
    new_op = nullptr;
  }
  return tensorflow::wrap(new_op);
}

void TFE_DeleteOp(TFE_Op* op) {
  if (op == nullptr) {
    return;
  }

  tensorflow::unwrap(op)->Release();
}

const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->Name().c_str();
}

TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
  return tensorflow::wrap(
      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
}

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->SetDeviceName(device_name);
}

const char* TFE_OpGetDevice(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->DeviceName().c_str();
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInput(tensorflow::unwrap(input));
}

void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  status->status = tensorflow::unwrap(op)->AddInputList(
      {reinterpret_cast<tensorflow::AbstractTensorHandle**>(
           tensorflow::unwrap(inputs)),
       static_cast<size_t>(num_inputs)});
}

extern int TFE_OpGetFlatInputCount(const TFE_Op* op, TF_Status* status) {
  return tensorflow::unwrap(op)->GetInputs().size();
}

extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, int index,
                                            TF_Status* status) {
  return tensorflow::wrap(tensorflow::unwrap(op)->GetInputs()[index]);
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret = TF_ATTR_INT;
  const tensorflow::AttrTypeMap* attr_types_;
  bool is_function;
  status->status = tensorflow::AttrTypeMapForOp(
      tensorflow::unwrap(op)->Name().c_str(), &attr_types_, &is_function);
  if (!status->status.ok()) {
    return ret;
  }
  status->status =
      tensorflow::AttrTypeByName(*attr_types_, attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (status->status.ok()) {
    ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  } else {
    ret = TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  TFE_DeleteOp(op);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  auto s = tensorflow::unwrap(op)->SetAttrString(
      attr_name, static_cast<const char*>(value), length);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  auto s = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  auto s = tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  auto s = tensorflow::unwrap(op)->SetAttrBool(attr_name,
                                               (value == 0) ? false : true);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  auto s = tensorflow::unwrap(op)->SetAttrType(
      attr_name, static_cast<tensorflow::DataType>(value));
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  out_status->status =
      tensorflow::unwrap(op)->SetAttrShape(attr_name, dims, num_dims);
}

void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  auto s = tensorflow::unwrap(op)->SetAttrFunction(
      attr_name, tensorflow::unwrap(const_cast<TFE_Op*>(value)));
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                               const char* data, size_t length) {
  auto s = tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                         TF_Status* status) {
  tensorflow::Tensor t;
  status->status = TF_TensorToTensor(tensor, &t);
  tensorflow::TensorInterface interface(t);
  status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
}

void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                             const void* const* values, const size_t* lengths,
                             int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrStringList(attr_name, values, lengths,
                                                     num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                            const float* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrTypeList(
      attr_name, reinterpret_cast<const tensorflow::DataType*>(values),
      num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  auto s =
      tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  out_status->status = tensorflow::unwrap(op)->SetAttrShapeList(
      attr_name, dims, num_dims, num_values);
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  auto s = tensorflow::unwrap(op)->SetAttrFunctionList(
      attr_name, {reinterpret_cast<const tensorflow::AbstractOperation**>(
                      tensorflow::unwrap(value)),
                  static_cast<size_t>(num_values)});
  if (!s.ok()) {
    LOG(WARNING) << "Unable to set attribute: " << attr_name;
  }
}

void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
                             const void* proto, size_t proto_len,
                             TF_Status* status) {
  tensorflow::AttrValue attr_value;
  if (!attr_value.ParseFromArray(proto, proto_len)) {
    status->status =
        tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
    return;
  }
  if (op == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Got a null or uninitialized `op` argument");
    return;
  }
  tensorflow::EagerOperation* operation =
      OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
  operation->MutableAttrs()->Set(attr_name, attr_value);
}

TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
                                               const char* input_name,
                                               TF_Status* status) {
  int ret = -1;
  status->status = tensorflow::unwrap(op)->InputLength(input_name, &ret);
  return ret;
}

TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
                                                const char* output_name,
                                                TF_Status* status) {
  int ret = -1;
  status->status = tensorflow::unwrap(op)->OutputLength(output_name, &ret);
  return ret;
}

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  status->status = tensorflow::unwrap(op)->Execute(
      absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
                         tensorflow::unwrap(retvals)),
                     *num_retvals),
      num_retvals);
}
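
// Illustrative sketch (not compiled here): building and running a single op.
// `a` and `b` are assumed to be existing TFE_TensorHandles for 2x2 matrices:
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_OpSetAttrBool(matmul, "transpose_a", 0);
//
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;  // in: capacity of `retvals`; out: actual count
//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(matmul);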

TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
      tensorflow::unwrap(h), device_name, &status->status);
  if (status->status.ok()) {
    return tensorflow::wrap(result);
  }
  return nullptr;
}

void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }

  AnnotateEagerRuntimeConstructionContext(function_def);
  status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  AnnotateEagerRuntimeConstructionContext(function->fdef);
  status->status = tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces(
      function->fdef, function->stack_traces);
}

void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                               TF_Status* status) {
  status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}

unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  return tensorflow::unwrap(ctx)->FindFunctionDef(name) != nullptr;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
                                      TF_Status* status) {
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(t));
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  auto* context = tensorflow::unwrap(ctx);
  status->status = context->AsyncWait();
  if (!status->status.ok()) return;
  auto run_metadata = context->ExportRunMetadata();
  status->status = MessageToBuffer(*run_metadata, buf);
}

namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (!status->status.ok()) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (!status->status.ok()) return nullptr;
  }
  return func_op;
}
}  // namespace

void TFE_ContextStartStep(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->StartStep();
}

void TFE_ContextEndStep(TFE_Context* ctx) {
  tensorflow::unwrap(ctx)->EndStep();
}

const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) {
  return tensorflow::wrap(
      &OperationFromInterface(tensorflow::unwrap(op))->Attrs());
}

void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
  tensorflow::EagerOperation* operation =
      OperationFromInterface(tensorflow::unwrap(op));
  tensorflow::AttrBuilder* destination = operation->MutableAttrs();
  destination->CopyAttributes(*tensorflow::unwrap(attrs));
}

void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
                          TF_Status* status) {
  tensorflow::NameAttrList name_and_attrs;
  tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
  name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
  status->status = MessageToBuffer(name_and_attrs, buf);
}

namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status) {
  switch (default_value.value_case()) {
    case tensorflow::AttrValue::kS: {
      const string& v = default_value.s();
      TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
      break;
    }
    case tensorflow::AttrValue::kI:
      TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
      break;
    case tensorflow::AttrValue::kF:
      TFE_OpSetAttrFloat(op, attr_name, default_value.f());
      break;
    case tensorflow::AttrValue::kB:
      TFE_OpSetAttrBool(op, attr_name, default_value.b());
      break;
    case tensorflow::AttrValue::kType:
      TFE_OpSetAttrType(op, attr_name,
                        static_cast<TF_DataType>(default_value.type()));
      break;
    case tensorflow::AttrValue::kShape: {
      const auto& tensor_shape = default_value.shape();
      if (tensor_shape.unknown_rank()) {
        TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
      } else {
        const auto num_dims = tensor_shape.dim_size();
        std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
        for (int i = 0; i < num_dims; ++i) {
          dims[i] = tensor_shape.dim(i).size();
        }
        TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
      }
    } break;
    case tensorflow::AttrValue::kFunc: {
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (!status->status.ok()) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require TFE_Op* and just convert it internally to a NameAttrValue, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList: {
      // String
      if (const int s_size = default_value.list().s_size()) {
        absl::InlinedVector<const void*, 4> values_vector;
        absl::InlinedVector<size_t, 4> lengths_vector;
        for (int i = 0; i < s_size; ++i) {
          const string& v = default_value.list().s(i);
          values_vector.push_back(v.data());
          lengths_vector.push_back(v.size());
        }
        TFE_OpSetAttrStringList(op, attr_name, values_vector.data(),
                                lengths_vector.data(), s_size);
      }

      // Int
      if (const int i_size = default_value.list().i_size()) {
        absl::InlinedVector<int64_t, 4> i_vector;
        for (int i = 0; i < i_size; ++i) {
          i_vector.push_back(default_value.list().i(i));
        }
        TFE_OpSetAttrIntList(op, attr_name, i_vector.data(), i_size);
      }
      // Float
      if (const int f_size = default_value.list().f_size()) {
        absl::InlinedVector<float, 4> f_vector;
        for (int i = 0; i < f_size; ++i) {
          f_vector.push_back(default_value.list().f(i));
        }
        TFE_OpSetAttrFloatList(op, attr_name, f_vector.data(), f_size);
      }
      // Bool
      if (const int b_size = default_value.list().b_size()) {
        absl::InlinedVector<unsigned char, 4> b_vector;
        for (int i = 0; i < b_size; i++) {
          b_vector.push_back(default_value.list().b(i));
        }
        TFE_OpSetAttrBoolList(op, attr_name, b_vector.data(), b_size);
      }
      // Type
      if (const int type_size = default_value.list().type_size()) {
        absl::InlinedVector<unsigned int, 4> type_vector;
        for (int i = 0; i < type_size; ++i) {
          type_vector.push_back(default_value.list().type(i));
        }
        TFE_OpSetAttrTypeList(
            op, attr_name,
            reinterpret_cast<const TF_DataType*>(type_vector.data()),
            type_size);
      }

      // Rest are not supported.
      if (default_value.list().shape_size() > 0 ||
          default_value.list().func_size() > 0 ||
          default_value.list().tensor_size() > 0) {
        TF_SetStatus(
            status, TF_UNIMPLEMENTED,
            tensorflow::strings::StrCat(
                "Unable to set attribute for default value: ",
                default_value.DebugString())
                .data());
      }
    } break;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::VALUE_NOT_SET:
      TF_SetStatus(
          status, TF_UNIMPLEMENTED,
          tensorflow::strings::StrCat(
              "Unable to set attribute for default value: ",
              default_value.DebugString())
              .data());
  }
}
}  // namespace tensorflow

namespace {
TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
                                          TFE_TensorHandle** handles,
                                          int num_handles, TF_Status* status,
                                          void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "This custom device does not support packing tensors.");
  return nullptr;
}
}  // namespace

extern "C" {

void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info,
                              TF_Status* status) {
  // Fill in default values for optional functionality.
  if (device.pack == nullptr) {
    device.pack = &DefaultCustomDevicePack;
  }
  auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
      ctx, device, device_info, device_name);
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status =
      context->RegisterCustomDevice(device_name, std::move(custom_device));
}

}  // extern "C"
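
// Illustrative registration sketch (hypothetical caller code): a
// TFE_CustomDevice is a struct of C callbacks; `pack` may be left null, in
// which case it is filled with DefaultCustomDevicePack above. The callback
// implementations (MyCopyToDevice, etc.) and the device name are placeholders
// assumed to exist in the caller:
//
//   TFE_CustomDevice device;
//   device.copy_tensor_to_device = &MyCopyToDevice;
//   device.copy_tensor_from_device = &MyCopyFromDevice;
//   device.execute = &MyExecute;
//   device.delete_device = &MyDeleteDevice;
//   device.pack = nullptr;  // use the default "unimplemented" pack
//   TFE_RegisterCustomDevice(
//       ctx, device, "/job:localhost/replica:0/task:0/device:CUSTOM:0",
//       /*device_info=*/my_state, status);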