1656 lines
65 KiB
C++
1656 lines
65 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/eager/c_api.h"
|
|
|
|
#include <algorithm>
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
|
|
|
// clang-format off
|
|
#include "tensorflow/core/platform/platform.h"
|
|
// clang-format on
|
|
|
|
#include "absl/algorithm/container.h"
|
|
#include "absl/memory/memory.h"
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
|
#include "tensorflow/c/eager/c_api_internal.h"
|
|
#include "tensorflow/c/eager/immediate_execution_operation.h"
|
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
|
#include "tensorflow/c/eager/tfe_context_internal.h"
|
|
#include "tensorflow/c/eager/tfe_op_internal.h"
|
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
|
#include "tensorflow/c/tf_tensor_internal.h"
|
|
#ifdef PLATFORM_GOOGLE
|
|
#include "tensorflow/core/tfrt/eager/c_api_tfrt.h"
|
|
#endif
|
|
#include "tensorflow/core/common_runtime/device.h"
|
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
|
#include "tensorflow/core/framework/function.h"
|
|
#include "tensorflow/core/platform/errors.h"
|
|
#include "tensorflow/core/platform/status.h"
|
|
#include "tensorflow/core/protobuf/device_filters.pb.h"
|
|
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
|
#include "tensorflow/core/util/device_name_utils.h"
|
|
#ifdef TENSORFLOW_EAGER_USE_XLA
|
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
|
#endif // TENSORFLOW_EAGER_USE_XLA
|
|
#include "tensorflow/core/common_runtime/copy_tensor.h"
|
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
|
#include "tensorflow/core/common_runtime/device_set.h"
|
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
|
#include "tensorflow/core/common_runtime/eager/execute.h"
|
|
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
|
#include "tensorflow/core/common_runtime/function.h"
|
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
|
#if !defined(IS_MOBILE_PLATFORM)
|
|
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
|
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
|
|
#include "tensorflow/core/distributed_runtime/remote_device.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
|
|
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
|
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
|
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
|
|
#endif // !IS_MOBILE_PLATFORM
|
|
#include "tensorflow/core/framework/node_def_util.h"
|
|
#include "tensorflow/core/framework/rendezvous.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
|
#include "tensorflow/core/framework/types.h"
|
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
|
#include "tensorflow/core/platform/blocking_counter.h"
|
|
#include "tensorflow/core/platform/casts.h"
|
|
#include "tensorflow/core/platform/env.h"
|
|
#include "tensorflow/core/platform/mutex.h"
|
|
#include "tensorflow/core/platform/notification.h"
|
|
#include "tensorflow/core/platform/random.h"
|
|
#include "tensorflow/core/platform/refcount.h"
|
|
#include "tensorflow/core/platform/stringpiece.h"
|
|
#include "tensorflow/core/platform/thread_annotations.h"
|
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
|
#include "tensorflow/core/public/version.h"
|
|
|
|
using tensorflow::int64;
|
|
using tensorflow::string;
|
|
|
|
namespace {
|
|
|
|
// Returns the name of device `d`. A null device is reported with the
// placeholder name "cpu:0".
string DeviceName(const tensorflow::Device* d) {
  if (d == nullptr) {
    return "cpu:0";
  }
  return d->name();
}
|
|
|
|
#if !defined(IS_MOBILE_PLATFORM)
|
|
// Reports whether `server_def` is compatible with the local devices already
// held by `context`. Two conditions must hold:
//   * the job name in `server_def` equals the job of the context's host CPU;
//   * the serialized default session config of `server_def` equals the
//     serialized config stored in the context's session options.
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
                               const tensorflow::ServerDef& server_def) {
  const bool same_job =
      server_def.job_name() == context->HostCPU()->parsed_name().job;
  if (!same_job) {
    return false;
  }
  const string requested_config =
      server_def.default_session_config().SerializeAsString();
  const string current_config =
      context->session_options().config.SerializeAsString();
  return requested_config == current_config;
}
|
|
|
|
// Queries every worker in `added_remote_workers` for its devices through
// `worker_cache`, blocks until all queries have reported back, and then hands
// the collected devices to `remote_device_mgr`.
//
// Returns the first non-OK per-worker status (in worker order); in that case
// no device is added to the manager. Otherwise returns the status of
// `AddDevices`.
tensorflow::Status AddRemoteDevicesToMgr(
    const std::vector<string>& added_remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  // Devices reported by all workers. The callbacks below may run
  // concurrently, so appends are serialized with `collected_mu`.
  std::vector<std::unique_ptr<tensorflow::Device>> collected;
  tensorflow::mutex collected_mu;

  const int worker_count = added_remote_workers.size();
  tensorflow::BlockingCounter pending(worker_count);
  std::vector<tensorflow::Status> worker_statuses(worker_count);

  for (int idx = 0; idx < worker_count; ++idx) {
    // Each callback writes only its own status slot, so the statuses vector
    // itself needs no lock.
    auto on_devices_listed =
        [idx, &worker_statuses, &pending, &collected, &collected_mu](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          worker_statuses[idx] = s;
          if (s.ok()) {
            tensorflow::mutex_lock lock(collected_mu);
            for (tensorflow::Device* device : *devices) {
              collected.emplace_back(device);
            }
          }
          pending.DecrementCount();
        };
    tensorflow::NewRemoteDevices(tensorflow::Env::Default(), worker_cache,
                                 added_remote_workers[idx], on_devices_listed);
  }
  pending.Wait();

  for (const tensorflow::Status& worker_status : worker_statuses) {
    TF_RETURN_IF_ERROR(worker_status);
  }

  return remote_device_mgr->AddDevices(std::move(collected));
}
|
|
|
|
// Builds a fresh DynamicDeviceMgr populated with the devices of all
// `remote_workers` and stores it in `*device_mgr`. On failure the error is
// returned and `*device_mgr` is left untouched.
tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
  std::unique_ptr<tensorflow::DynamicDeviceMgr> fresh_mgr =
      absl::make_unique<tensorflow::DynamicDeviceMgr>();
  const tensorflow::Status populate_status =
      AddRemoteDevicesToMgr(remote_workers, worker_cache, fresh_mgr.get());
  if (!populate_status.ok()) {
    return populate_status;
  }
  *device_mgr = std::move(fresh_mgr);
  return tensorflow::Status::OK();
}
|
|
|
|
// Removes from `remote_device_mgr` every device that shares an address space
// with at least one worker listed in `removed_remote_workers`. Returns the
// status of the manager's `RemoveDevices` call.
tensorflow::Status RemoveRemoteDevicesFromMgr(
    const std::vector<string>& removed_remote_workers,
    tensorflow::DynamicDeviceMgr* remote_device_mgr) {
  const std::vector<tensorflow::Device*> known_devices =
      remote_device_mgr->ListDevices();

  std::vector<tensorflow::Device*> doomed_devices;
  for (tensorflow::Device* device : known_devices) {
    const bool on_removed_worker = std::any_of(
        removed_remote_workers.begin(), removed_remote_workers.end(),
        [device](const string& worker) {
          return tensorflow::DeviceNameUtils::IsSameAddressSpace(
              worker, device->name());
        });
    if (on_removed_worker) {
      doomed_devices.emplace_back(device);
    }
  }
  return remote_device_mgr->RemoveDevices(doomed_devices);
}
|
|
|
|
// Fills `remote_workers` with the workers listed by the worker cache of the
// server's master environment, with every entry equal to `local_worker`
// removed. Returns an Internal error when `server` is not a
// tensorflow::GrpcServer (this includes a null `server`).
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
                                     const string& local_worker,
                                     std::vector<string>* remote_workers) {
  auto* grpc_server = dynamic_cast<tensorflow::GrpcServer*>(server);
  if (grpc_server != nullptr) {
    grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
    const auto first_local = std::remove(remote_workers->begin(),
                                         remote_workers->end(), local_worker);
    remote_workers->erase(first_local, remote_workers->end());
    return tensorflow::Status::OK();
  }
  return tensorflow::errors::Internal(
      "Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
}
|
|
|
|
// Splits two sorted worker lists into three output lists in a single merge
// pass (the equivalent of two set_difference calls plus a set_intersection):
//   * `added`    - entries present only in `new_list`;
//   * `removed`  - entries present only in `current_list`;
//   * `existing` - entries present in both lists.
// As with std::set_difference, both input lists must already be sorted.
// Any previous content of the three output vectors is discarded.
void DifferentiateWorkerLists(const std::vector<string>* current_list,
                              const std::vector<string>* new_list,
                              std::vector<string>* added,
                              std::vector<string>* removed,
                              std::vector<string>* existing) {
  added->clear();
  removed->clear();
  existing->clear();

  auto cur = current_list->cbegin();
  const auto cur_end = current_list->cend();
  auto incoming = new_list->cbegin();
  const auto incoming_end = new_list->cend();

  while (cur != cur_end && incoming != incoming_end) {
    if (*cur < *incoming) {
      // Only in the current list: the worker went away.
      removed->push_back(*cur);
      ++cur;
    } else if (*cur > *incoming) {
      // Only in the new list: the worker is new.
      added->push_back(*incoming);
      ++incoming;
    } else {
      // In both lists: the worker is kept.
      existing->push_back(*cur);
      ++cur;
      ++incoming;
    }
  }

  // Whatever is left over on either side has no counterpart in the other.
  removed->insert(removed->end(), cur, cur_end);
  added->insert(added->end(), incoming, incoming_end);
}
|
|
|
|
// Probes every worker in `existing_workers` with a KeepAlive request for
// `context_id` and appends to `replaced_workers` each worker that should be
// treated as replaced. A worker is considered replaced when no client could
// be obtained for it, when the KeepAlive RPC fails, or when the view id it
// reports differs from `context_view_id`.
//
// Blocks until all probes have completed. Always returns OK; per-worker
// failures are reflected only in `replaced_workers`. `server_def` is not read
// by this function.
tensorflow::Status GetReplacedFromExistingWorkers(
    const std::vector<string>* existing_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* client_cache,
    std::vector<string>* replaced_workers) {
  const size_t worker_count = existing_workers->size();
  tensorflow::BlockingCounter pending(worker_count);
  std::vector<tensorflow::Status> probe_statuses(worker_count);

  // A single request object is shared by all probes; it stays alive until
  // `pending.Wait()` returns.
  tensorflow::eager::KeepAliveRequest keep_alive;
  keep_alive.set_context_id(context_id);
  std::vector<tensorflow::eager::KeepAliveResponse> probe_responses(
      worker_count);

  for (size_t idx = 0; idx < worker_count; ++idx) {
    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> client;
    probe_statuses[idx] =
        client_cache->GetClient(existing_workers->at(idx), &client);
    if (probe_statuses[idx].ok()) {
      client->KeepAliveAsync(
          &keep_alive, &probe_responses[idx],
          [idx, &probe_statuses, &pending](const tensorflow::Status& s) {
            probe_statuses[idx] = s;
            pending.DecrementCount();
          });
    } else {
      // No RPC will be issued for this worker; account for it right away.
      pending.DecrementCount();
    }
  }
  pending.Wait();

  for (size_t idx = 0; idx < worker_count; ++idx) {
    // If the RPC fails (indicating that the requested ID doesn't exist on
    // remote), or the returned view ID is not equal to the local one
    // (indicating that the remote worker has a stale view of cluster), treat
    // the worker as replaced.
    const bool rpc_failed = !probe_statuses[idx].ok();
    if (rpc_failed ||
        probe_responses[idx].context_view_id() != context_view_id) {
      replaced_workers->emplace_back(existing_workers->at(idx));
    }
  }
  return tensorflow::Status::OK();
}
|
|
|
|
// Issues a CreateContext RPC to every worker in `remote_workers` and blocks
// until each one has either completed or failed to be issued.
//
// Each request carries `context_id`, `context_view_id`, a copy of
// `server_def` retargeted at the worker's own job/task, the subset of
// `base_request`'s cluster device attributes that passes the context's device
// filters for that worker, and the `async`, `keep_alive_secs` and
// `lazy_copy_remote_function_inputs` settings.
//
// Returns the first non-OK status in worker order, or OK if all succeeded.
tensorflow::Status CreateRemoteContexts(
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    tensorflow::uint64 context_id, tensorflow::uint64 context_view_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    const bool lazy_copy_remote_function_inputs,
    const tensorflow::eager::CreateContextRequest& base_request) {
  const int worker_count = remote_workers.size();
  tensorflow::BlockingCounter pending(worker_count);
  std::vector<tensorflow::Status> worker_statuses(worker_count);

  for (int w = 0; w < worker_count; ++w) {
    const string& worker = remote_workers[w];

    tensorflow::DeviceNameUtils::ParsedName worker_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(worker, &worker_name)) {
      worker_statuses[w] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", worker, " as a device name");
      pending.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> client;
    worker_statuses[w] = remote_eager_workers->GetClient(worker, &client);
    if (client == nullptr) {
      worker_statuses[w] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", worker);
    }
    if (!worker_statuses[w].ok()) {
      pending.DecrementCount();
      continue;
    }

    tensorflow::eager::CreateContextRequest request;
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);

    // Start from the caller's server_def and point it at this worker.
    tensorflow::ServerDef* worker_server_def = request.mutable_server_def();
    *worker_server_def = server_def;
    worker_server_def->set_job_name(worker_name.job);
    worker_server_def->set_task_index(worker_name.task);
    worker_server_def->mutable_default_session_config()->MergeFrom(
        server_def.default_session_config());

    // Forward only the cluster devices that this worker is allowed to see.
    std::vector<bool> device_mask;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->FilterDevicesForRemoteWorkers(
        worker, base_request.cluster_device_attributes(), &device_mask);
    DCHECK_EQ(device_mask.size(),
              base_request.cluster_device_attributes_size());
    const int mask_size = device_mask.size();
    for (int d = 0; d < mask_size; ++d) {
      if (!device_mask[d]) continue;
      *request.add_cluster_device_attributes() =
          base_request.cluster_device_attributes(d);
    }

    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    request.set_lazy_copy_remote_function_inputs(
        lazy_copy_remote_function_inputs);

    // The response is heap-allocated and released by the completion callback.
    auto* response = new tensorflow::eager::CreateContextResponse();
    client->CreateContextAsync(
        &request, response,
        [w, &worker_statuses, &pending,
         response](const tensorflow::Status& s) {
          worker_statuses[w] = s;
          delete response;
          pending.DecrementCount();
        });
  }
  pending.Wait();

  for (const tensorflow::Status& worker_status : worker_statuses) {
    TF_RETURN_IF_ERROR(worker_status);
  }
  return tensorflow::Status::OK();
}
|
|
|
|
// Issues an UpdateContext RPC to every worker in `remote_workers` and blocks
// until each one has either completed or failed to be issued.
//
// A worker receives a full update (retargeted `server_def` plus its filtered
// cluster device attributes) only when at least one device visible to it
// belongs to a task listed in `added_workers` or `removed_workers`. All other
// workers receive a request carrying just `context_id` and `context_view_id`.
//
// Returns the first non-OK status in worker order, or OK if all succeeded.
tensorflow::Status UpdateRemoteContexts(
    TFE_Context* ctx, const std::vector<string>& remote_workers,
    const std::vector<string>& added_workers,
    const std::vector<string>& removed_workers, tensorflow::uint64 context_id,
    tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers,
    const tensorflow::eager::CreateContextRequest& base_request) {
  const int worker_count = remote_workers.size();
  tensorflow::BlockingCounter pending(worker_count);
  std::vector<tensorflow::Status> worker_statuses(worker_count);

  const int cluster_device_count =
      base_request.cluster_device_attributes_size();

  // Union of the added and removed worker names.
  std::unordered_set<string> changed_workers(added_workers.begin(),
                                             added_workers.end());
  changed_workers.insert(removed_workers.begin(), removed_workers.end());

  // Whether each device is in the updated (added or removed) workers
  std::vector<bool> device_on_changed_worker(cluster_device_count);
  for (int d = 0; d < cluster_device_count; ++d) {
    const auto& attributes = base_request.cluster_device_attributes(d);
    tensorflow::DeviceNameUtils::ParsedName parsed_device;
    tensorflow::DeviceNameUtils::ParseFullName(attributes.name(),
                                               &parsed_device);
    string task_name;
    tensorflow::DeviceNameUtils::GetTaskName(parsed_device, &task_name);
    if (changed_workers.count(task_name) != 0) {
      device_on_changed_worker[d] = true;
    }
  }

  for (int w = 0; w < worker_count; ++w) {
    const string& worker = remote_workers[w];

    tensorflow::DeviceNameUtils::ParsedName worker_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(worker, &worker_name)) {
      worker_statuses[w] = tensorflow::errors::InvalidArgument(
          "Unable to parse ", worker, " as a device name");
      pending.DecrementCount();
      continue;
    }

    tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> client;
    worker_statuses[w] = remote_eager_workers->GetClient(worker, &client);
    if (client == nullptr) {
      worker_statuses[w] = tensorflow::errors::Internal(
          "Cannot find a client for the given target:", worker);
    }
    if (!worker_statuses[w].ok()) {
      pending.DecrementCount();
      continue;
    }

    std::vector<bool> device_mask;
    tensorflow::EagerContext* context =
        tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
    context->FilterDevicesForRemoteWorkers(
        worker, base_request.cluster_device_attributes(), &device_mask);
    DCHECK_EQ(device_mask.size(), cluster_device_count);

    // If any of the devices that match the device filters are in the set of
    // added or removed workers, we must send a complete UpdateContextRequest.
    // Otherwise, only send a simple request to increment context view ID.
    bool any_visible_device_changed = false;
    for (int d = 0; d < cluster_device_count; ++d) {
      if (device_on_changed_worker[d] && device_mask[d]) {
        any_visible_device_changed = true;
        break;
      }
    }
    const bool full_update_request = any_visible_device_changed;

    tensorflow::eager::UpdateContextRequest request;
    request.set_context_id(context_id);
    request.set_context_view_id(context_view_id);
    if (full_update_request) {
      // Start from the caller's server_def and point it at this worker.
      tensorflow::ServerDef* worker_server_def = request.mutable_server_def();
      *worker_server_def = server_def;
      worker_server_def->set_job_name(worker_name.job);
      worker_server_def->set_task_index(worker_name.task);
      worker_server_def->mutable_default_session_config()->MergeFrom(
          server_def.default_session_config());
      for (int d = 0; d < cluster_device_count; ++d) {
        if (!device_mask[d]) continue;
        *request.add_cluster_device_attributes() =
            base_request.cluster_device_attributes(d);
      }
    }

    // The response is heap-allocated and released by the completion callback.
    auto* response = new tensorflow::eager::UpdateContextResponse();
    client->UpdateContextAsync(
        &request, response,
        [w, &worker_statuses, &pending,
         response](const tensorflow::Status& s) {
          worker_statuses[w] = s;
          delete response;
          pending.DecrementCount();
        });
  }
  pending.Wait();

  for (const tensorflow::Status& worker_status : worker_statuses) {
    TF_RETURN_IF_ERROR(worker_status);
  }
  return tensorflow::Status::OK();
}
|
|
|
|
// Connects the eager context `ctx` to the cluster described by `server_def`.
//
// When `reset_context` is true:
//   * a new server is created from `server_def` (reusing the context's local
//     device manager when it is compatible),
//   * a new context id is generated and the view id restarts at 0,
//   * remote contexts are created on every remote worker,
//   * the master context is initialized via `InitializeRemoteMaster`, and
//   * the server is started last.
//   Any failure aborts immediately.
//
// When `reset_context` is false:
//   * the context's existing server is updated in place with `server_def`,
//   * the old and new worker lists are diffed into added / removed /
//     existing / replaced workers,
//   * remote devices and remote contexts are added, removed or updated to
//     match, and the master context is updated.
//   Errors from these update steps are accumulated in a StatusGroup and
//   reported once at the end, so that a failure in the middle does not leave
//   the cluster with inconsistent context views.
//
// `keep_alive_secs` is forwarded to the remote contexts and to the master
// context. Every error returned from this function is logged first.
tensorflow::Status UpdateTFE_ContextWithServerDef(
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    TFE_Context* ctx, bool reset_context) {
  // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
  // server object (which currently CHECK-fails) and we miss the error, instead,
  // we log the error, and then return to allow the user to see the error
  // message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  // List of current remote workers before updating server_def. Unused if
  // resetting the server_def.
  std::vector<string> curr_remote_workers;
  // List of updated remote workers.
  std::vector<string> remote_workers;

  // New server created for new server_def. Unused if updating server_def.
  std::unique_ptr<tensorflow::ServerInterface> new_server;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  tensorflow::GrpcServer* grpc_server = nullptr;
  if (reset_context) {
    const tensorflow::DeviceMgr* device_mgr =
        AreLocalDevicesCompatible(context, server_def)
            ? context->local_device_mgr()
            : nullptr;
    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
        server_def, {device_mgr}, &new_server));
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
    // `ListRemoteWorkers` fails if the server is not a GrpcServer, so
    // `grpc_server` is known to be non-null once this call has succeeded.
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
  } else {
    LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
                                              &curr_remote_workers));
    // No need to check the cast here, since `ListRemoteWorkers` already checks
    // if the server is a GRPC server or not.
    grpc_server = dynamic_cast<tensorflow::GrpcServer*>(context->GetServer());
    LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
    LOG_AND_RETURN_IF_ERROR(
        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
  }

  tensorflow::uint64 context_id = context->GetContextId();
  tensorflow::uint64 context_view_id = context->GetContextViewId();
  if (reset_context) {
    context_id = tensorflow::EagerContext::NewContextId();
    context_view_id = 0;
    // Make master eager context accessible by local eager service, which might
    // receive send tensor requests from remote workers.
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->AddMasterEagerContextToEagerService(context_id, context));
  }

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  LOG_AND_RETURN_IF_ERROR(
      grpc_server->master_env()->worker_cache->GetEagerClientCache(
          &remote_eager_workers));

  // For cluster update, use a status group to aggregate statuses from
  //   * adding and removing remote devices
  //   * creating remote contexts on newly added workers
  //   * updating remote contexts on existing workers
  //   * updating the master context
  // Note that we should not return immediately on errors in the middle of these
  // updates to prevent cluster from having inconsistent context views.
  //
  // Unused if `reset_context` is True.
  tensorflow::StatusGroup sg;

  // When updating an existing context, populate the following lists with:
  //   * added_workers: set(remote_workers) - set(curr_remote_workers)
  //   * removed_workers: set(curr_remote_workers) - set(remote_workers)
  //   * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
  //   * replaced_workers: workers with the same task names and potentially the
  //     same `hostname:port`s, but replaced by different processes
  std::vector<string> added_workers;
  std::vector<string> removed_workers;
  std::vector<string> existing_workers;
  std::vector<string> replaced_workers;

  // New remote device manager created for new server_def. Unused if updating
  // server_def.
  std::unique_ptr<tensorflow::DynamicDeviceMgr> new_remote_device_mgr;
  tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr;
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
        remote_workers, grpc_server->master_env()->worker_cache,
        &new_remote_device_mgr));
    remote_device_mgr = new_remote_device_mgr.get();
  } else {
    context->ClearCachesAndDefaultExecutor();
    // TODO(b/143914772): Potential memory leak if rendezvous has pending
    // tensors for removed / replaced workers.

    remote_device_mgr = context->GetOwnedRemoteDeviceMgr();
    if (remote_device_mgr == nullptr) {
      LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
          "Updating context with an invalid set of remote devices."));
    }
    // `DifferentiateWorkerLists` requires both inputs to be sorted.
    std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
    std::sort(remote_workers.begin(), remote_workers.end());
    DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                             &added_workers, &removed_workers,
                             &existing_workers);
    sg.Update(GetReplacedFromExistingWorkers(
        &existing_workers, context_id, context->GetContextViewId(), server_def,
        remote_eager_workers.get(), &replaced_workers));
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Updating cluster with following changes";
      for (const string& w : added_workers) VLOG(1) << "  Added worker " << w;
      for (const string& w : removed_workers)
        VLOG(1) << "  Removed worker " << w;
      for (const string& w : replaced_workers)
        VLOG(1) << "  Replaced worker " << w;
    }
    if (!replaced_workers.empty()) {
      // Treat replaced workers as removed then added back, so that we recreate
      // remote devices and contexts, and re-register functions on those workers
      removed_workers.insert(removed_workers.end(), replaced_workers.begin(),
                             replaced_workers.end());
      added_workers.insert(added_workers.end(), replaced_workers.begin(),
                           replaced_workers.end());
      for (const string& w : replaced_workers) {
        existing_workers.erase(
            std::remove(existing_workers.begin(), existing_workers.end(), w),
            existing_workers.end());
      }
    }
    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
    sg.Update(AddRemoteDevicesToMgr(added_workers,
                                    grpc_server->master_env()->worker_cache,
                                    remote_device_mgr));
  }

  std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
  remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);

  std::vector<tensorflow::DeviceAttributes> local_device_attributes;
  grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
      &local_device_attributes);

  // This request make sure that we can create Rendezvous properly between
  // Local and Remote context.
  tensorflow::eager::CreateContextRequest base_request;
  for (const auto& da : cluster_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }
  for (const auto& da : local_device_attributes) {
    *base_request.add_cluster_device_attributes() = da;
  }

  // Initialize remote eager workers.
  if (reset_context) {
    LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
        ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
  } else {
    // The master's context_view_id will be incremented by one
    // the UpdateRemoteMaster call later. We want all new workers and
    // existing workers to also have the updated context_view_id, so
    // we must set their context_view_id to the existing master's
    // context_view_id + 1.
    sg.Update(CreateRemoteContexts(
        ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
        server_def, remote_eager_workers.get(), context->Executor().Async(),
        context->LazyCopyFunctionRemoteInputs(), base_request));
    if (!existing_workers.empty()) {
      if (VLOG_IS_ON(1)) {
        for (const string& w : existing_workers) {
          VLOG(1) << "Updating cluster with existing worker " << w;
        }
      }
      sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
                                     removed_workers, context_id,
                                     context_view_id + 1, server_def,
                                     remote_eager_workers.get(), base_request));
    }
  }

  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
  if (reset_context) {
    tensorflow::RemoteRendezvous* r =
        grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
    auto* device_mgr = grpc_server->worker_env()->device_mgr;
    std::shared_ptr<tensorflow::WorkerSession> worker_session;
    // These three steps go through LOG_AND_RETURN_IF_ERROR like everything
    // else in this function: returning here destroys `new_server`, so the
    // error has to be logged before we return or it would be lost.
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->CreateSession(
            session_name, server_def, base_request.cluster_device_attributes(),
            true));
    LOG_AND_RETURN_IF_ERROR(
        grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
            session_name, &worker_session));

    // Initialize remote tensor communication based on worker session.
    LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

    tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
        tensorflow::eager::CreateClusterFLR(context_id, context,
                                            worker_session.get());
    auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
        /*is_master=*/true, context);

    LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster(
        std::move(new_server), grpc_server->worker_env(), worker_session,
        std::move(remote_eager_workers), std::move(new_remote_device_mgr),
        remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr,
        std::move(remote_mgr)));

    // NOTE: We start the server after all other initialization, because the
    // GrpcServer cannot be destroyed after it is started.
    LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
  } else {
    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
        session_name, server_def, base_request.cluster_device_attributes(),
        /*isolate_session_state=*/true));
    sg.Update(context->UpdateRemoteMaster(context_id,
                                          std::move(remote_eager_workers),
                                          added_workers, removed_workers));
    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
  }
#undef LOG_AND_RETURN_IF_ERROR

  return tensorflow::Status::OK();
}
|
|
#endif // !IS_MOBILE_PLATFORM
|
|
|
|
} // namespace
|
|
|
|
extern "C" {
|
|
|
|
// Allocates a new TFE_ContextOptions on the heap and returns it. The matching
// release function in this file is TFE_DeleteContextOptions.
TFE_ContextOptions* TFE_NewContextOptions() {
  TFE_ContextOptions* fresh_options = new TFE_ContextOptions;
  return fresh_options;
}
|
|
|
|
// Applies the serialized config in `proto` (`proto_len` bytes) to the session
// options embedded in `options` by delegating to TF_SetConfig, which receives
// `status` for reporting the outcome.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  auto* embedded_session_options = &options->session_options;
  TF_SetConfig(embedded_session_options, proto, proto_len, status);
}
|
|
|
|
// Stores `enable` in the `async` field of `options`.
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  TFE_ContextOptions& target = *options;
  target.async = enable;
}
|
|
|
|
// Stores `policy` in the `device_placement_policy` field of `options`.
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  TFE_ContextOptions& target = *options;
  target.device_placement_policy = policy;
}
|
|
|
|
// Destroys `options` with `delete`. Passing a null pointer is a no-op, as for
// any `delete` expression.
void TFE_DeleteContextOptions(TFE_ContextOptions* options) {
  delete options;
}
|
|
|
|
// Creates a new eager context configured by `opts`.
//
// If `opts->use_tfrt` is set, a TFRT context is built when compiled with
// PLATFORM_GOOGLE; otherwise `status` is set to Unimplemented and nullptr is
// returned.
//
// In the default path, local devices are created under the name prefix
// "/job:localhost/replica:0/task:0" and wrapped in a StaticDeviceMgr, and an
// EagerContext is constructed with that manager (passed as owned) and an
// IntraProcessRendezvous built on it.
//
// Returns nullptr with the error stored in `status` if device discovery
// fails.
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
    tfrt::SmallVector<std::string, 4> op_handler_chains;
    tfrt::SmallVector<tensorflow::DeviceAttributes, 4> device_attributes;
    status->status = tfrt::ListOpHandlerChains(
        opts->session_options.options, &op_handler_chains, &device_attributes);
    if (!status->status.ok()) return nullptr;
    return tensorflow::wrap(new tfrt::ContextInterface(
        op_handler_chains, device_attributes, opts->async));
#else
    status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
    return nullptr;
#endif
  }

  const auto& session_options = opts->session_options.options;

  std::vector<std::unique_ptr<tensorflow::Device>> local_devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &local_devices);
  if (!status->status.ok()) {
    return nullptr;
  }

  std::unique_ptr<tensorflow::DeviceMgr> local_device_mgr(
      new tensorflow::StaticDeviceMgr(std::move(local_devices)));
  tensorflow::Rendezvous* rendezvous =
      new tensorflow::IntraProcessRendezvous(local_device_mgr.get());

  const auto placement_policy =
      static_cast<tensorflow::ContextDevicePlacementPolicy>(
          opts->device_placement_policy);
  const auto mirroring_policy =
      static_cast<tensorflow::ContextMirroringPolicy>(opts->mirroring_policy);

  // The device manager is released from the unique_ptr and passed with
  // `device_mgr_owned` set to true.
  tensorflow::EagerContext* eager_context = new tensorflow::EagerContext(
      session_options, placement_policy, mirroring_policy, opts->async,
      opts->lazy_remote_inputs_copy, local_device_mgr.release(),
      /*device_mgr_owned*/ true, rendezvous,
      tensorflow::GetDefaultCustomKernelCreator());
  return tensorflow::wrap(eager_context);
}
|
|
|
|
// Releases the caller's reference on `ctx`. A null `ctx` is a no-op.
void TFE_DeleteContext(TFE_Context* ctx) {
  if (ctx != nullptr) {
    // ctx->RefCountIsOne() should be true here.
    tensorflow::unwrap(ctx)->Release();
  }
}
|
|
|
|
// Returns a newly allocated TF_DeviceList filled in by the context's
// ListDevices(). `status` is not written by this implementation.
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* device_list = new TF_DeviceList;
  tensorflow::unwrap(ctx)->ListDevices(&device_list->response);
  return device_list;
}
|
|
|
|
// Forwards to the context's ClearCachesAndThreadExecutors().
void TFE_ContextClearCaches(TFE_Context* ctx) {
  auto* context = tensorflow::unwrap(ctx);
  context->ClearCachesAndThreadExecutors();
}
|
|
|
|
// Set server_def on the context, possibly updating it.
|
|
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
|
|
int keep_alive_secs,
|
|
const void* proto,
|
|
size_t proto_len,
|
|
TF_Status* status) {
|
|
#if defined(IS_MOBILE_PLATFORM)
|
|
status->status = tensorflow::errors::Unimplemented(
|
|
"TFE_ContextSetServerDef not supported on mobile");
|
|
#else // !defined(IS_MOBILE_PLATFORM)
|
|
tensorflow::ServerDef server_def;
|
|
if (!server_def.ParseFromArray(proto, proto_len)) {
|
|
status->status = tensorflow::errors::InvalidArgument(
|
|
"Invalid tensorflow.ServerDef protocol buffer");
|
|
return;
|
|
}
|
|
if (server_def.has_cluster_device_filters()) {
|
|
const auto& cdf = server_def.cluster_device_filters();
|
|
for (const auto& jdf : cdf.jobs()) {
|
|
const string remote_prefix = "/job:" + jdf.name() + "/task:";
|
|
for (const auto& tdf : jdf.tasks()) {
|
|
const int32_t task_index = tdf.first;
|
|
std::vector<string> device_filters(tdf.second.device_filters_size());
|
|
for (int i = 0; i < tdf.second.device_filters_size(); i++) {
|
|
device_filters[i] = tdf.second.device_filters(i);
|
|
}
|
|
const string remote_worker = remote_prefix + std::to_string(task_index);
|
|
tensorflow::EagerContext* context =
|
|
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
status->status =
|
|
context->SetRemoteDeviceFilters(remote_worker, device_filters);
|
|
}
|
|
}
|
|
}
|
|
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
|
|
ctx, /*reset_context=*/true);
|
|
#endif // !IS_MOBILE_PLATFORM
|
|
}
|
|
|
|
// Updates an already-initialized context with a new ServerDef
// (reset_context=false).
//
// `proto`/`proto_len` must hold a serialized tensorflow.ServerDef. Device
// filters in the new ServerDef are ignored (a warning is logged).
//
// Errors are reported through `status`:
//   * Unimplemented on mobile platforms.
//   * InvalidArgument if the proto cannot be parsed, or if the context's id is
//     EagerContext::kInvalidContextId. In both cases the context is left
//     untouched (previously the invalid-id error was overwritten because the
//     function fell through to the update).
//   * Otherwise, the result of UpdateTFE_ContextWithServerDef.
TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
                                                      int keep_alive_secs,
                                                      const void* proto,
                                                      size_t proto_len,
                                                      TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextUpdateServerDef not supported on mobile");
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::ServerDef server_def;
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  if (context->GetContextId() == tensorflow::EagerContext::kInvalidContextId) {
    status->status = tensorflow::errors::InvalidArgument(
        "Trying to update a context with invalid context id.");
    return;
  }
  if (server_def.has_cluster_device_filters()) {
    LOG(WARNING) << "Device filters can only be specified when initializing "
                    "the cluster. Any changes in device filters are ignored "
                    "when updating the server def.";
  }
  // TODO(haoyuzhang): Check server_def compatibility before the update
  status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
                                                  ctx, /*reset_context=*/false);
#endif  // !IS_MOBILE_PLATFORM
}
|
|
|
|
// Checks whether the remote worker `worker_name` is reachable by sending it a
// KeepAlive request and blocking until the reply (or failure) arrives.
//
// Returns true if the RPC succeeds or fails with INVALID_ARGUMENT (see the
// comment near the end), false otherwise.
//
// `status` carries an error only when the check itself could not be performed
// (unsupported platform, no server, no client cache, unknown worker). When the
// RPC was issued, `status` is OK regardless of the returned value.
TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
                                                 const char* worker_name,
                                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::errors::Unimplemented(
      "TFE_ContextCheckAlive not supported on mobile");
  return false;
#else   // !defined(IS_MOBILE_PLATFORM)
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  tensorflow::GrpcServer* grpc_server =
      static_cast<tensorflow::GrpcServer*>(context->GetServer());
  // Guard against a context without a server (presumably one on which no
  // ServerDef has been set yet - confirm) instead of dereferencing null.
  if (grpc_server == nullptr) {
    status->status =
        tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer.");
    return false;
  }

  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
  status->status = grpc_server->master_env()->worker_cache->GetEagerClientCache(
      &remote_eager_workers);
  if (!status->status.ok()) {
    LOG(ERROR) << "Failed to get client cache for remote workers.";
    return false;
  }

  // TODO(yuefengz): support partially specified `worker_name`.
  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> eager_client;
  status->status = remote_eager_workers->GetClient(worker_name, &eager_client);
  if (!status->status.ok()) {
    return false;
  }

  // Send a rpc request to the worker to check aliveness.
  tensorflow::eager::KeepAliveRequest request;
  request.set_context_id(context->GetContextId());
  tensorflow::eager::KeepAliveResponse response;

  // The callback captures these locals by reference; that is safe because we
  // block on `done` before leaving this scope.
  tensorflow::Status keep_alive_status;
  tensorflow::Notification done;
  eager_client->KeepAliveAsync(
      &request, &response,
      [&keep_alive_status, &done](const tensorflow::Status& s) {
        keep_alive_status = s;
        done.Notify();
      });
  done.WaitForNotification();

  status->status = tensorflow::Status::OK();

  // If `context_id` doesn't exist on the remote worker, an InvalidArgument
  // error will return. But this still indicates that the remote worker is
  // alive.
  if (keep_alive_status.ok() ||
      keep_alive_status.code() == tensorflow::error::INVALID_ARGUMENT) {
    return true;
  } else {
    LOG(INFO) << "Remote worker " << worker_name
              << " is not alive: " << keep_alive_status.error_message();
    return false;
  }
#endif  // !IS_MOBILE_PLATFORM
}
|
|
|
|
// Stores the result of the context's AsyncWait() in `status`. On mobile
// platforms the context is not consulted and `status` is simply set to OK.
TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
                                                TF_Status* status) {
#if !defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::unwrap(ctx)->AsyncWait();
#else   // defined(IS_MOBILE_PLATFORM)
  status->status = tensorflow::Status::OK();
#endif  // IS_MOBILE_PLATFORM
}
|
|
|
|
// Converts `policy` to its C++ enum and forwards it to the context's
// SetThreadLocalDevicePlacementPolicy().
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  const auto cc_policy =
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy);
  tensorflow::ContextFromInterface(tensorflow::unwrap(ctx))
      ->SetThreadLocalDevicePlacementPolicy(cc_policy);
}
|
|
|
|
// Note: this function looks up a thread local policy. So it should be called in
|
|
// the appropriate client thread. In particular, in async mode, it may not be
|
|
// safe to call this function from the async EagerExecutor threads.
|
|
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
|
|
TFE_Context* ctx) {
|
|
tensorflow::EagerContext* context =
|
|
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
return static_cast<TFE_ContextDevicePlacementPolicy>(
|
|
context->GetDevicePlacementPolicy());
|
|
}
|
|
|
|
// Converts the C tensor `t` to a tensorflow::Tensor and wraps it in a new
// local tensor handle. Returns nullptr, with the error in `status`, if the
// conversion fails.
TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor converted;
  status->status = tensorflow::TF_TensorToTensor(t, &converted);
  if (status->status.ok()) {
    return tensorflow::wrap(
        tensorflow::TensorHandle::CreateLocalHandle(converted));
  }
  return nullptr;
}
|
|
|
|
// Releases the caller's reference on `h`, recording a profiler TraceMe event
// around the release. A null `h` is a no-op and is not traced.
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  if (h == nullptr) {
    return;
  }

  tensorflow::profiler::TraceMe activity(
      "TFE_DeleteTensorHandle", tensorflow::profiler::TraceMeLevel::kInfo);
  // `h` is known to be non-null at this point.
  tensorflow::unwrap(h)->Release();
}
|
|
|
|
// Returns the handle's DataType() converted to the C enum. `h` is
// dereferenced without a null check.
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  const auto dtype = tensorflow::unwrap(h)->DataType();
  return static_cast<TF_DataType>(dtype);
}
|
|
|
|
// Returns the value produced by the handle's NumDims(), whose status is
// stored in `status`. A null `h` yields -1 with an InvalidArgument status.
// The result starts out as -1, so that is what is returned if NumDims() does
// not write to it.
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  int rank = -1;
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  } else {
    status->status = tensorflow::unwrap(h)->NumDims(&rank);
  }
  return rank;
}
|
|
|
|
// Returns the value produced by the handle's NumElements(), whose status is
// stored in `status`. A null `h` yields -1 with an InvalidArgument status.
// The result starts out as -1, so that is what is returned if NumElements()
// does not write to it.
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
  int64 element_count = -1;
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  } else {
    status->status = tensorflow::unwrap(h)->NumElements(&element_count);
  }
  return element_count;
}
|
|
|
|
// Returns the value produced by the handle's Dim(dim_index, ...), whose
// status is stored in `status`. A null `h` yields -1 with an InvalidArgument
// status. The result starts out as -1, so that is what is returned if Dim()
// does not write to it.
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  int64 extent = -1;
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  } else {
    status->status = tensorflow::unwrap(h)->Dim(dim_index, &extent);
  }
  return extent;
}
|
|
|
|
// Returns the handle's DeviceName(), which reports its own status through
// `status`. A null `h` yields nullptr with an InvalidArgument status.
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h != nullptr) {
    return tensorflow::unwrap(h)->DeviceName(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return nullptr;
}
|
|
|
|
// Returns the handle's BackingDeviceName(), which reports its own status
// through `status`. A null `h` yields nullptr with an InvalidArgument status.
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h != nullptr) {
    return tensorflow::unwrap(h)->BackingDeviceName(&status->status);
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return nullptr;
}
|
|
|
|
// Returns the result of the handle's Copy(). A null `h` yields nullptr with
// an InvalidArgument status; otherwise `status` is not written.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h != nullptr) {
    return tensorflow::wrap(tensorflow::unwrap(h)->Copy());
  }
  status->status = tensorflow::errors::InvalidArgument("Invalid handle");
  return nullptr;
}
|
|
|
|
// Resolves `h` to a concrete tensor and returns it in a newly allocated
// TF_Tensor. Returns nullptr if `h` is null (InvalidArgument) or if Resolve()
// produced no tensor (Resolve() reports its own status through `status`).
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  tensorflow::AbstractTensorInterface* resolved =
      tensorflow::unwrap(h)->Resolve(&status->status);
  return resolved == nullptr ? nullptr : new TF_Tensor{resolved};
}
|
|
|
|
// Returns a pointer to the tensor data backing `h`.
//
// * Custom-device handles: returns the tensor's data() pointer directly.
// * Other handles must be LOCAL; the handle's device (if any) is Sync()'d
//   first, then the address of the tensor's tensor_data() is returned.
//
// Returns nullptr with the error in `status` if `h` is null, the handle is
// not local, the device sync fails, or the tensor cannot be obtained.
void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (VariantDeviceIsCustom(handle->device())) {
    const tensorflow::Tensor* t = nullptr;
    status->status = handle->Tensor(&t);
    // Do not touch `t` unless Tensor() succeeded; it was previously
    // dereferenced unconditionally while still uninitialized on failure.
    if (!status->status.ok()) {
      return nullptr;
    }
    return t->data();
  }

  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ",
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
  if (device != nullptr) {
    status->status = device->Sync();
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(
      static_cast<const void*>(tensor->tensor_data().data()));
}
|
|
|
|
// Wraps caller-provided memory (`data`, `len`) in a tensor of type `dtype`
// and shape `dims[0..num_dims)`, and returns a local handle for it on the
// device called `device_name`.
//
// The name is first looked up among the context's regular devices and then
// among its custom devices. If neither lookup succeeds, `deallocator` is
// invoked on the memory right away and nullptr is returned with the error in
// `status`.
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
    void (*deallocator)(void* data, size_t len, void* arg),
    void* deallocator_arg, TF_Status* status) {
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));

  tensorflow::Device* device = nullptr;
  tensorflow::CustomDevice* custom_device = nullptr;
  status->status = context->FindDeviceFromName(device_name, &device);
  if (!status->status.ok()) {
    // Not a regular device; fall back to the custom devices.
    status->status =
        context->FindCustomDeviceFromName(device_name, &custom_device);
    if (!status->status.ok()) {
      deallocator(data, len, deallocator_arg);
      return nullptr;
    }
  }

  std::vector<tensorflow::int64> shape_dims(num_dims);
  for (int d = 0; d < num_dims; ++d) {
    shape_dims[d] = static_cast<tensorflow::int64>(dims[d]);
  }

  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
  TF_ManagedBuffer* buffer =
      new TF_ManagedBuffer(data, len, deallocator, deallocator_arg,
                           /*owns_memory=*/false);

  tensorflow::Tensor tensor(static_cast<tensorflow::DataType>(dtype),
                            tensorflow::TensorShape(shape_dims), buffer);
  // Drop the local reference taken by `new` above.
  buffer->Unref();

  if (custom_device != nullptr) {
    return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
        std::move(tensor), custom_device, context));
  }
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
      std::move(tensor), device, device, context));
}
|
|
|
|
// This function will block till the operation that produces `h` has
|
|
// completed. This is only valid on local TFE_TensorHandles. Returns the size in
|
|
// bytes of the memory pointed to by the device pointer returned above.
|
|
size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
|
|
TF_Status* status) {
|
|
if (h == nullptr) {
|
|
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
|
|
return 0;
|
|
}
|
|
tensorflow::TensorHandle* handle =
|
|
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
|
|
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
|
|
status->status = tensorflow::errors::InvalidArgument(
|
|
"TFE_TensorHandleDeviceMemorySize may not be called on a ",
|
|
handle->TypeString(), " tensor handle.");
|
|
return 0;
|
|
}
|
|
const tensorflow::Tensor* tensor;
|
|
status->status = handle->Tensor(&tensor);
|
|
if (!status->status.ok()) {
|
|
return 0;
|
|
}
|
|
return tensor->TotalBytes();
|
|
}
|
|
|
|
// Creates an operation on `ctx` and resets it to `op_or_function_name`.
// If the reset fails, the operation is released and nullptr is returned with
// the error in `status`.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  tensorflow::ImmediateExecutionOperation* operation =
      tensorflow::unwrap(ctx)->CreateOperation();
  status->status = operation->Reset(op_or_function_name, nullptr);
  if (status->status.ok()) {
    return tensorflow::wrap(operation);
  }
  operation->Release();
  operation = nullptr;
  return tensorflow::wrap(operation);
}
|
|
|
|
// Releases the caller's reference on `op`. A null `op` is a no-op.
void TFE_DeleteOp(TFE_Op* op) {
  if (op != nullptr) {
    tensorflow::unwrap(op)->Release();
  }
}
|
|
|
|
// Forwards `device_name` to the operation's SetDeviceName() and stores the
// result in `status`.
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  auto* operation = tensorflow::unwrap(op);
  status->status = operation->SetDeviceName(device_name);
}
|
|
|
|
// Returns the C string of the operation's DeviceName(). The pointer refers to
// the string returned by DeviceName(), not to a copy. `status` is not written
// by this implementation.
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  auto* operation = tensorflow::unwrap(op);
  return operation->DeviceName().c_str();
}
|
|
|
|
// Requests (or cancels) XLA compilation for `op` via SetUseXla(). Failures
// are logged rather than returned. When the library is built without
// TENSORFLOW_EAGER_USE_XLA the call only logs a warning.
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#else
  tensorflow::Status xla_status = tensorflow::unwrap(op)->SetUseXla(enable);
  if (!xla_status.ok()) {
    LOG(ERROR) << "Could not enable XLA compilation for op: " << xla_status;
  }
#endif  // TENSORFLOW_EAGER_USE_XLA
}
|
|
|
|
// Passes `input` to the operation's AddInput() and stores the result in
// `status`.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  auto* input_handle = tensorflow::unwrap(input);
  status->status = tensorflow::unwrap(op)->AddInput(input_handle);
}
|
|
|
|
// Passes the `num_inputs` handles in `inputs` to the operation's
// AddInputList() as a single list and stores the result in `status`.
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  auto** handles = reinterpret_cast<tensorflow::AbstractTensorHandle**>(
      tensorflow::unwrap(inputs));
  const size_t handle_count = static_cast<size_t>(num_inputs);
  status->status =
      tensorflow::unwrap(op)->AddInputList({handles, handle_count});
}
|
|
|
|
// Looks up the type of attribute `attr_name` for the op named by `op`, passing
// `is_list` through to AttrTypeByName().
//
// If the op's attribute type map cannot be obtained, TF_ATTR_INT is returned
// as a dummy value with the error in `status`.
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType attr_type = TF_ATTR_INT;
  const tensorflow::AttrTypeMap* type_map;
  bool is_function;
  status->status = tensorflow::AttrTypeMapForOp(
      tensorflow::unwrap(op)->Name().c_str(), &type_map, &is_function);
  if (status->status.ok()) {
    status->status =
        tensorflow::AttrTypeByName(*type_map, attr_name, &attr_type, is_list);
  }
  return attr_type;
}
|
|
|
|
// Same as TFE_OpGetAttrType, but takes an op or function name instead of an
// existing TFE_Op. A temporary op is created for the lookup and deleted
// before returning. If the op cannot be created, TF_ATTR_INT is returned with
// the error in `status`.
TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TFE_Op* temp_op = TFE_NewOp(ctx, op_or_function_name, status);
  // TF_ATTR_INT is the same dummy return as TFE_OpGetAttrType.
  const TF_AttrType attr_type =
      status->status.ok()
          ? TFE_OpGetAttrType(temp_op, attr_name, is_list, status)
          : TF_ATTR_INT;
  TFE_DeleteOp(temp_op);
  return attr_type;
}
|
|
|
|
// Sets string attribute `attr_name` from the `length` bytes at `value`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  const char* bytes = static_cast<const char*>(value);
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrString(attr_name, bytes, length);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets integer attribute `attr_name` to `value`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  const auto set_status = tensorflow::unwrap(op)->SetAttrInt(attr_name, value);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets float attribute `attr_name` to `value`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrFloat(attr_name, value);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets boolean attribute `attr_name`; any non-zero `value` means true.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  const bool flag = (value != 0);
  const auto set_status = tensorflow::unwrap(op)->SetAttrBool(attr_name, flag);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets type attribute `attr_name` to `value` (converted to the C++ enum).
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  const auto dtype = static_cast<tensorflow::DataType>(value);
  const auto set_status = tensorflow::unwrap(op)->SetAttrType(attr_name, dtype);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets shape attribute `attr_name` from `dims[0..num_dims)` by forwarding to
// the operation's SetAttrShape(); the result is stored in `out_status`.
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  auto* operation = tensorflow::unwrap(op);
  out_status->status = operation->SetAttrShape(attr_name, dims, num_dims);
}
|
|
|
|
// Sets function attribute `attr_name` on `op` from the op `value`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  // unwrap() is applied to a non-const TFE_Op*, hence the const_cast.
  auto* function_op = tensorflow::unwrap(const_cast<TFE_Op*>(value));
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrFunction(attr_name, function_op);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets function attribute `attr_name` from a function name given as the
// `length` bytes at `data`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                               const char* data, size_t length) {
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrFunctionName(attr_name, data, length);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets tensor attribute `attr_name` on `op` from the C tensor `tensor`.
//
// `status` receives the conversion error if `tensor` cannot be converted to a
// tensorflow::Tensor (in which case the attribute is left untouched), and
// otherwise the result of the operation's SetAttrTensor().
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                         TF_Status* status) {
  tensorflow::Tensor t;
  status->status = TF_TensorToTensor(tensor, &t);
  // Previously a conversion failure was overwritten by the next assignment
  // and the attribute was set from a default-constructed tensor.
  if (!status->status.ok()) return;
  tensorflow::TensorInterface interface(t);
  status->status = tensorflow::unwrap(op)->SetAttrTensor(attr_name, &interface);
}
|
|
|
|
// Sets list(string) attribute `attr_name` from `num_values` entries, where
// entry i is described by `values[i]` and `lengths[i]`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                             const void* const* values, const size_t* lengths,
                             int num_values) {
  const auto set_status = tensorflow::unwrap(op)->SetAttrStringList(
      attr_name, values, lengths, num_values);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets list(float) attribute `attr_name` from `values[0..num_values)`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                            const float* values, int num_values) {
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrFloatList(attr_name, values, num_values);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets list(int) attribute `attr_name` from `values[0..num_values)`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrIntList(attr_name, values, num_values);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets list(type) attribute `attr_name` from `values[0..num_values)`. The C
// enum array is reinterpreted in place as an array of tensorflow::DataType.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  const auto* dtypes = reinterpret_cast<const tensorflow::DataType*>(values);
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrTypeList(attr_name, dtypes, num_values);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets list(bool) attribute `attr_name` from `values[0..num_values)`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  const auto set_status =
      tensorflow::unwrap(op)->SetAttrBoolList(attr_name, values, num_values);
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
// Sets list(shape) attribute `attr_name` by forwarding `dims`, `num_dims` and
// `num_values` to the operation's SetAttrShapeList(); the result is stored in
// `out_status`.
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  auto* operation = tensorflow::unwrap(op);
  out_status->status =
      operation->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
|
|
|
|
// Sets list(func) attribute `attr_name` from the `num_values` ops in `value`.
// A failure is only logged as a warning; it is not reported to the caller.
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  auto** functions = reinterpret_cast<const tensorflow::AbstractOperation**>(
      tensorflow::unwrap(value));
  const size_t function_count = static_cast<size_t>(num_values);
  const auto set_status = tensorflow::unwrap(op)->SetAttrFunctionList(
      attr_name, {functions, function_count});
  if (set_status.ok()) return;
  LOG(WARNING) << "Unable to set attribute: " << attr_name;
}
|
|
|
|
void TFE_OpSetAttrValueProto(const TFE_Op* op, const char* attr_name,
|
|
const void* proto, size_t proto_len,
|
|
TF_Status* status) {
|
|
tensorflow::AttrValue attr_value;
|
|
if (!attr_value.ParseFromArray(proto, proto_len)) {
|
|
status->status =
|
|
tensorflow::errors::InvalidArgument("Unparseable AttrValue proto");
|
|
return;
|
|
}
|
|
if (op == nullptr) {
|
|
status->status = tensorflow::errors::InvalidArgument(
|
|
"Got a null or uninitialized `op` argument");
|
|
return;
|
|
}
|
|
tensorflow::EagerOperation* operation =
|
|
OperationFromInterface(tensorflow::unwrap(const_cast<TFE_Op*>(op)));
|
|
operation->MutableAttrs()->Set(attr_name, attr_value);
|
|
}
|
|
|
|
// Returns the value produced by the operation's InputLength() for
// `input_name`, with its status stored in `status`. The result starts out as
// -1, so that is what is returned if InputLength() does not write to it.
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
                                               const char* input_name,
                                               TF_Status* status) {
  int length = -1;
  auto* operation = tensorflow::unwrap(op);
  status->status = operation->InputLength(input_name, &length);
  return length;
}
|
|
|
|
// Returns the value produced by the operation's OutputLength() for
// `output_name`, with its status stored in `status`. The result starts out as
// -1, so that is what is returned if OutputLength() does not write to it.
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
                                                const char* output_name,
                                                TF_Status* status) {
  int length = -1;
  auto* operation = tensorflow::unwrap(op);
  status->status = operation->OutputLength(output_name, &length);
  return length;
}
|
|
|
|
// Executes `op`. `retvals` is presented to the operation as a span of
// `*num_retvals` output slots, and `num_retvals` itself is also passed so the
// operation can update it. The result is stored in `status`.
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  auto** output_slots = reinterpret_cast<tensorflow::AbstractTensorHandle**>(
      tensorflow::unwrap(retvals));
  status->status = tensorflow::unwrap(op)->Execute(
      absl::MakeSpan(output_slots, *num_retvals), num_retvals);
}
|
|
|
|
// Copies `h` to the device called `device_name` using the context's
// CopyTensorHandleToDevice(). Returns the new handle, or nullptr with the
// error in `status` if `h` is null or the copy did not succeed.
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }

  auto* copied = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice(
      tensorflow::unwrap(h), device_name, &status->status);
  if (!status->status.ok()) {
    return nullptr;
  }
  return tensorflow::wrap(copied);
}
|
|
|
|
// Parses `size` bytes at `serialized_function_def` as a tensorflow.FunctionDef
// and adds it to the context. Reports InvalidArgument through `status` if the
// bytes cannot be parsed; otherwise `status` is the result of AddFunctionDef().
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef parsed_def;
  if (parsed_def.ParseFromArray(serialized_function_def, size)) {
    status->status = tensorflow::unwrap(ctx)->AddFunctionDef(parsed_def);
  } else {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
  }
}
|
|
|
|
// Adds the FunctionDef held by `function` (`function->fdef`) to the context
// and stores the result in `status`.
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  auto* context = tensorflow::unwrap(ctx);
  status->status = context->AddFunctionDef(function->fdef);
}
|
|
|
|
// Forwards `name` to the context's RemoveFunction() and stores the result in
// `status`.
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
                               TF_Status* status) {
  auto* context = tensorflow::unwrap(ctx);
  status->status = context->RemoveFunction(name);
}
|
|
|
|
// Returns non-zero iff the context's FindFunctionDef() finds a function
// called `name`.
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  tensorflow::EagerContext* eager_context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  const bool found = eager_context->FindFunctionDef(name) != nullptr;
  return found;
}
|
|
|
|
// Turns run-metadata collection on by calling SetShouldStoreGraphs(true) on
// the context.
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  tensorflow::ContextFromInterface(tensorflow::unwrap(ctx))
      ->SetShouldStoreGraphs(true);
}
|
|
|
|
// Turns run-metadata collection off by calling SetShouldStoreGraphs(false) on
// the context.
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::ContextFromInterface(tensorflow::unwrap(ctx))
      ->SetShouldStoreGraphs(false);
}
|
|
|
|
} // extern "C"
|
|
|
|
// C++ overload: wraps an existing tensorflow::Tensor in a new local tensor
// handle. `status` is not written by this implementation.
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
                                      TF_Status* status) {
  auto* local_handle = tensorflow::TensorHandle::CreateLocalHandle(t);
  return tensorflow::wrap(local_handle);
}
|
|
|
|
// Serializes the context's collected RunMetadata proto into `buf` and then
// clears it on the context.
//
// Waits for all pending executor nodes first; if that wait fails, the error
// is returned in `status` and nothing is exported or cleared. The metadata
// mutex is held while serializing and clearing.
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::EagerContext* eager_context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status = eager_context->Executor().WaitForAllPendingNodes();
  if (!status->status.ok()) {
    return;
  }
  tensorflow::mutex_lock metadata_lock(*eager_context->MetadataMu());
  status->status = MessageToBuffer(*eager_context->RunMetadataProto(), buf);
  eager_context->ClearRunMetadata();
}
|
|
|
|
namespace {
|
|
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
|
|
TF_Status* status) {
|
|
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
|
|
for (const auto& attr : func.attr()) {
|
|
if (!status->status.ok()) return nullptr;
|
|
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
|
|
if (!status->status.ok()) return nullptr;
|
|
}
|
|
return func_op;
|
|
}
|
|
} // namespace
|
|
|
|
// Forwards to the context's StartStep().
void TFE_ContextStartStep(TFE_Context* ctx) {
  auto* context = tensorflow::unwrap(ctx);
  context->StartStep();
}
|
|
|
|
// Forwards to the context's EndStep().
void TFE_ContextEndStep(TFE_Context* ctx) {
  auto* context = tensorflow::unwrap(ctx);
  context->EndStep();
}
|
|
|
|
// Returns a pointer to the attributes of `op`. The pointer refers to the
// object returned by the operation's Attrs(), not to a copy.
const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
  tensorflow::EagerOperation* operation =
      OperationFromInterface(tensorflow::unwrap(op));
  return tensorflow::wrap(&operation->Attrs());
}
|
|
|
|
// Copies all attributes in `attrs` into the mutable attributes of `op` via
// CopyAttributes().
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
  tensorflow::AttrBuilder* target_attrs =
      OperationFromInterface(tensorflow::unwrap(op))->MutableAttrs();
  target_attrs->CopyAttributes(*tensorflow::unwrap(attrs));
}
|
|
|
|
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
|
|
TF_Status* status) {
|
|
tensorflow::NameAttrList name_and_attrs;
|
|
tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr());
|
|
name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name());
|
|
status->status = MessageToBuffer(name_and_attrs, buf);
|
|
}
|
|
|
|
namespace tensorflow {
|
|
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
|
const tensorflow::AttrValue& default_value,
|
|
const char* attr_name, TF_Status* status) {
|
|
switch (default_value.value_case()) {
|
|
case tensorflow::AttrValue::kS: {
|
|
const string& v = default_value.s();
|
|
TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
|
|
break;
|
|
}
|
|
case tensorflow::AttrValue::kI:
|
|
TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
|
|
break;
|
|
case tensorflow::AttrValue::kF:
|
|
TFE_OpSetAttrFloat(op, attr_name, default_value.f());
|
|
break;
|
|
case tensorflow::AttrValue::kB:
|
|
TFE_OpSetAttrBool(op, attr_name, default_value.b());
|
|
break;
|
|
case tensorflow::AttrValue::kType:
|
|
TFE_OpSetAttrType(op, attr_name,
|
|
static_cast<TF_DataType>(default_value.type()));
|
|
break;
|
|
case tensorflow::AttrValue::kShape: {
|
|
const auto& tensor_shape = default_value.shape();
|
|
if (tensor_shape.unknown_rank()) {
|
|
TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
|
|
} else {
|
|
const auto num_dims = tensor_shape.dim_size();
|
|
std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
|
|
for (int i = 0; i < num_dims; ++i) {
|
|
dims[i] = tensor_shape.dim(i).size();
|
|
}
|
|
TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
|
|
}
|
|
} break;
|
|
case tensorflow::AttrValue::kFunc: {
|
|
const auto func_op = GetFunc(ctx, default_value.func(), status);
|
|
if (!status->status.ok()) return;
|
|
// TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
|
|
// require TFE_Op* and just convert it internally a NameAttrValue, so
|
|
// consider adding an overload to the C API to make this case easier.
|
|
TFE_OpSetAttrFunction(op, attr_name, func_op);
|
|
TFE_DeleteOp(func_op);
|
|
} break;
|
|
case tensorflow::AttrValue::kList:
|
|
TF_FALLTHROUGH_INTENDED;
|
|
case tensorflow::AttrValue::kTensor:
|
|
TF_FALLTHROUGH_INTENDED;
|
|
case tensorflow::AttrValue::kPlaceholder:
|
|
TF_FALLTHROUGH_INTENDED;
|
|
case tensorflow::AttrValue::VALUE_NOT_SET:
|
|
TF_SetStatus(
|
|
status, TF_UNIMPLEMENTED,
|
|
tensorflow::strings::StrCat("Unable to get setfor default value: ",
|
|
default_value.DebugString())
|
|
.data());
|
|
}
|
|
}
|
|
} // namespace tensorflow
|
|
|
|
namespace {
|
|
class CustomDeviceAPI : public tensorflow::CustomDevice {
|
|
public:
|
|
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
|
|
string name)
|
|
: context_(context), device_(device), info_(info), name_(name) {}
|
|
|
|
~CustomDeviceAPI() override { device_.delete_device(info_); }
|
|
|
|
const string& name() override { return name_; }
|
|
|
|
tensorflow::Status CopyTensorToDevice(
|
|
tensorflow::TensorHandle* handle,
|
|
tensorflow::TensorHandle** result) override {
|
|
handle->Ref();
|
|
TF_Status status;
|
|
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
|
|
context_, tensorflow::wrap(handle), &status, info_);
|
|
handle->Release();
|
|
if (!status.status.ok()) return status.status;
|
|
*result = tensorflow::TensorHandleFromInterface(
|
|
tensorflow::unwrap(result_handle));
|
|
(*result)->Ref();
|
|
TFE_DeleteTensorHandle(result_handle);
|
|
return status.status;
|
|
}
|
|
|
|
tensorflow::Status CopyTensorFromDevice(
|
|
tensorflow::TensorHandle* handle,
|
|
const tensorflow::string& target_device_name,
|
|
tensorflow::TensorHandle** result) override {
|
|
TF_Status status;
|
|
handle->Ref();
|
|
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
|
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
|
|
info_);
|
|
handle->Release();
|
|
if (!status.status.ok()) return status.status;
|
|
*result = tensorflow::TensorHandleFromInterface(
|
|
tensorflow::unwrap(result_handle));
|
|
(*result)->Ref();
|
|
TFE_DeleteTensorHandle(result_handle);
|
|
return status.status;
|
|
}
|
|
|
|
tensorflow::Status Execute(tensorflow::EagerOperation* op,
|
|
tensorflow::TensorHandle** retvals,
|
|
int* num_retvals) override {
|
|
std::vector<TFE_TensorHandle*> inputs;
|
|
inputs.reserve(op->Inputs().size());
|
|
for (int i = 0; i < op->Inputs().size(); ++i) {
|
|
op->Inputs()[i]->Ref();
|
|
inputs.push_back(tensorflow::wrap(op->Inputs()[i]));
|
|
}
|
|
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
|
TF_Status status;
|
|
device_.execute(context_, inputs.size(), inputs.data(), op->Name().c_str(),
|
|
wrap(&op->Attrs()), num_retvals, outputs.data(), &status,
|
|
info_);
|
|
if (status.status.ok()) {
|
|
for (int i = 0; i < *num_retvals; ++i) {
|
|
retvals[i] = tensorflow::TensorHandleFromInterface(
|
|
tensorflow::unwrap(outputs[i]));
|
|
retvals[i]->Ref();
|
|
TFE_DeleteTensorHandle(outputs[i]);
|
|
}
|
|
}
|
|
|
|
for (auto inp : inputs) {
|
|
TFE_DeleteTensorHandle(inp);
|
|
}
|
|
return status.status;
|
|
}
|
|
|
|
private:
|
|
TFE_Context* context_;
|
|
TFE_CustomDevice device_;
|
|
void* info_;
|
|
string name_;
|
|
};
|
|
} // namespace
|
|
|
|
extern "C" {
|
|
|
|
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
|
const char* device_name, void* device_info,
|
|
TF_Status* status) {
|
|
auto custom_device =
|
|
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
|
|
tensorflow::EagerContext* context =
|
|
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
|
status->status =
|
|
context->RegisterCustomDevice(device_name, std::move(custom_device));
|
|
}
|
|
|
|
} // extern "C"
|