Open sourcing some TPU-related work

PiperOrigin-RevId: 315431095
Change-Id: I734632c0e5723dfca37acf53bbbd2b378b04c95d
Frank Chen 2020-06-08 23:59:12 -07:00 committed by TensorFlower Gardener
parent d9e5e2f7b3
commit 13c09da422
57 changed files with 7422 additions and 0 deletions

View File: tensorflow/core/tpu/graph_rewrite/BUILD

@@ -0,0 +1,55 @@
# Contains graph rewrites for TPU runtimes and optimizations.
package(
default_visibility = [
"//tensorflow/core/tpu:__subpackages__",
"//tensorflow/stream_executor/tpu:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "distributed_tpu_configuration_rewrite_registration",
srcs = ["distributed_tpu_configuration_rewrite_registration.cc"],
deps = [
":distributed_tpu_configuration_rewrite_pass",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)
cc_library(
name = "distributed_tpu_configuration_rewrite_pass",
srcs = [
"distributed_tpu_configuration_rewrite_pass.cc",
],
hdrs = [
"distributed_tpu_configuration_rewrite_pass.h",
],
deps = [
":distributed_tpu_rewrite_helpers",
"//tensorflow/cc:scope",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/core/tpu:tpu_init_mode",
"//tensorflow/core/tpu/kernels:tpu_compile_op_options",
],
)
cc_library(
name = "distributed_tpu_rewrite_helpers",
srcs = ["distributed_tpu_rewrite_helpers.cc"],
hdrs = ["distributed_tpu_rewrite_helpers.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/tpu:tpu_defs",
],
)
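
The `distributed_tpu_configuration_rewrite_registration` target above exports nothing that other code links against; its only effect is the static initializer produced by the REGISTER_OPTIMIZATION macros in its source file (shown further down in this commit), which is why `alwayslink = 1` is needed to keep the linker from dropping the object file. A generic sketch of the idiom being protected, with an illustrative toy registry and registrar (not part of this change):

#include <functional>
#include <vector>

// A toy global registry; the real code uses OptimizationPassRegistry.
std::vector<std::function<void()>>& GlobalRegistry() {
  static auto* registry = new std::vector<std::function<void()>>();
  return *registry;
}

namespace {
// The constructor runs at program start and registers a callback as a side
// effect; nothing else references this object, which is exactly why the
// object file must be force-linked (alwayslink = 1 in Bazel terms).
struct MyPassRegistrar {
  MyPassRegistrar() {
    GlobalRegistry().push_back([] { /* run the pass */ });
  }
};
static MyPassRegistrar my_pass_registrar;
}  // namespace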

View File: tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc

@@ -0,0 +1,402 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Configuration for distributed TPU jobs
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
#include <unordered_map>
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/tpu_init_mode.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
namespace {
constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
constexpr int kDefaultStartupTimeout = 20;  // seconds
Status AddConfigurationNode(const string& configuration_device_name,
int number_of_hosts, Graph* graph,
bool enable_whole_mesh_compilations,
Node** configuration_node) {
NodeDef config_def;
config_def.set_name(graph->NewName("configure_distributed_tpu"));
config_def.set_op(kInternalConfigureOp);
config_def.set_device(configuration_device_name);
AddNodeAttr("N", number_of_hosts, &config_def);
AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
&config_def);
// TODO(shikharagarwal): Fill with appropriate original node debug info.
Status status;
*configuration_node = graph->AddNode(config_def, &status);
if (!status.ok()) {
return status;
}
(*configuration_node)->set_assigned_device_name(configuration_device_name);
return Status::OK();
}
Status AddHostConfigNode(const string& host_device_name,
Node* configuration_node, Graph* graph,
bool enable_whole_mesh_compilations,
Node** host_configuration_node) {
NodeDef host_config_def;
host_config_def.set_name(graph->NewName("configure_tpu_host"));
host_config_def.set_op(kHostConfigureOp);
host_config_def.set_device(host_device_name);
AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
&host_config_def);
MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);
Status status;
*host_configuration_node = graph->AddNode(host_config_def, &status);
if (!status.ok()) {
return status;
}
(*host_configuration_node)->set_assigned_device_name(host_device_name);
graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
return Status::OK();
}
Status AddWaitNode(const string& configuration_device_name,
const std::vector<Node*>& host_configuration_nodes,
Graph* graph, Node** wait_node) {
NodeDef wait_def;
wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
wait_def.set_op(kWaitOp);
wait_def.set_device(configuration_device_name);
AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
&wait_def);
AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
if (!host_configuration_nodes.empty()) {
MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
&wait_def);
}
Status status;
*wait_node = graph->AddNode(wait_def, &status);
if (!status.ok()) {
return status;
}
(*wait_node)->set_assigned_device_name(configuration_device_name);
// Get the inputs from the host configuration nodes.
for (int i = 0; i < host_configuration_nodes.size(); ++i) {
graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
}
return Status::OK();
}
Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
Graph* graph, Node** global_tpu_array_node) {
NodeDef global_tpu_array_def;
global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
global_tpu_array_def.set_op(kGlobalTPUArrayOp);
global_tpu_array_def.set_device(host_device_name);
MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);
Status status;
*global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
if (!status.ok()) {
return status;
}
(*global_tpu_array_node)->set_assigned_device_name(host_device_name);
graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
return Status::OK();
}
Status AddSynchronizationNode(
const NodeDef& sync_node_def, const string& device_name,
const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
output_dependencies,
Graph* graph) {
NodeDef sync_def;
sync_def.set_name(sync_node_def.name());
sync_def.set_op(kIdentityOp);
sync_def.set_device(device_name);
AddNodeAttr("T", DT_STRING, &sync_def);
MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);
Status status;
Node* sync_node = graph->AddNode(sync_def, &status);
if (!status.ok()) {
return status;
}
sync_node->set_assigned_device_name(device_name);
// Add control edges from the global array id nodes.
for (auto node : global_array_id_nodes) {
graph->AddControlEdge(node, sync_node);
}
// Forward the data from the wait node.
graph->AddEdge(wait_node, 0, sync_node, 0);
// Replace the output edges.
for (const DistributedTPURewriteHelpers::OutputDependency& dep :
output_dependencies) {
if (dep.dst_input == Graph::kControlSlot) {
graph->AddControlEdge(sync_node, dep.dst);
} else {
graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
}
}
return Status::OK();
}
Status AddShutdownNode(
const NodeDef& shutdown_node_def, const string& shutdown_device_name,
const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
output_dependencies,
Graph* graph, Node** shutdown_node) {
NodeDef shutdown_def;
shutdown_def.set_name(shutdown_node_def.name());
shutdown_def.set_op(kInternalShutdownOp);
shutdown_def.set_device(shutdown_device_name);
MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);
Status status;
*shutdown_node = graph->AddNode(shutdown_def, &status);
if (!status.ok()) {
return status;
}
(*shutdown_node)->set_assigned_device_name(shutdown_device_name);
// Replace the output control edges.
for (const DistributedTPURewriteHelpers::OutputDependency& dep :
output_dependencies) {
if (dep.dst_input != Graph::kControlSlot) {
return errors::Internal("Shutdown node had non-control edge output");
}
graph->AddControlEdge(*shutdown_node, dep.dst);
}
return Status::OK();
}
Status AddHostDisconnectNode(const string& host_device_name,
const std::vector<Node*>& input_dependencies,
Node* post_disconnect_node, int output_index,
Graph* graph) {
NodeDef host_disconnect_def;
host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
host_disconnect_def.set_op(kHostDisconnectOp);
host_disconnect_def.set_device(host_device_name);
MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
&host_disconnect_def);
Status status;
Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
if (!status.ok()) {
return status;
}
host_disconnect_node->set_assigned_device_name(host_device_name);
// Replace the input control edges.
for (Node* src_node : input_dependencies) {
graph->AddControlEdge(src_node, host_disconnect_node);
}
if (output_index == -1) {
graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
} else {
graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
}
return Status::OK();
}
} // namespace
Status DistributedTPUConfigurationRewritePass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";
Graph* graph = options.graph->get();
if (VLOG_IS_ON(1)) {
DumpGraphToFile("distributed_tpu_configuration_before", *graph,
options.flib_def);
}
// This pass can only run in the session master, which should fill
// in the device_set field of the options.
TF_RET_CHECK(options.device_set != nullptr);
TF_RETURN_IF_ERROR(
DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
kConfigureOp, graph, *options.device_set,
[](const NodeDef& configuration_node_def,
const string& configuration_device_name,
const std::vector<Device*>& host_devices,
const std::vector<Node*>& input_dependencies,
const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
output_dependencies,
Graph* graph) -> Status {
const std::string& embedding_attr_string = GetNodeAttrString(
AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);
if (!embedding_attr_string.empty()) {
return errors::InvalidArgument("embedding_config must be empty.");
}
bool is_global_init = false;
bool enable_whole_mesh_compilations = false;
TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
"is_global_init", &is_global_init));
TryGetNodeAttr(configuration_node_def,
"enable_whole_mesh_compilations",
&enable_whole_mesh_compilations);
TF_RETURN_IF_ERROR(SetTPUInitMode(
is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));
bool compilation_failure_closes_chips;
TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
"compilation_failure_closes_chips",
&compilation_failure_closes_chips));
internal::SetTpuCompilationFailureClosesChips(
compilation_failure_closes_chips);
// Add the global TPU system configuration node.
Node* configuration_node;
TF_RETURN_IF_ERROR(AddConfigurationNode(
configuration_device_name, host_devices.size(), graph,
enable_whole_mesh_compilations, &configuration_node));
// Add the host disconnect nodes.
for (int i = 0; i < host_devices.size(); ++i) {
const auto host_device = host_devices[i];
TF_RETURN_IF_ERROR(
AddHostDisconnectNode(host_device->name(), input_dependencies,
configuration_node, i, graph));
}
// Add the host configuration nodes.
std::vector<Node*> host_configuration_nodes;
for (const auto host_device : host_devices) {
Node* host_configuration_node;
TF_RETURN_IF_ERROR(AddHostConfigNode(
host_device->name(), configuration_node, graph,
enable_whole_mesh_compilations, &host_configuration_node));
host_configuration_nodes.push_back(host_configuration_node);
}
// Add the node to wait for the system configuration to
// stabilize. Use the name of the original dummy Op in case it was
// the target of a Session::Run call.
Node* wait_node;
TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
host_configuration_nodes, graph,
&wait_node));
// Add the nodes to set the global TPU ids at each host.
std::vector<Node*> global_array_id_nodes;
for (const auto host_device : host_devices) {
Node* global_array_id_node;
TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
wait_node, graph,
&global_array_id_node));
global_array_id_nodes.push_back(global_array_id_node);
}
if (host_devices.empty()) {
return errors::InvalidArgument("TPU job contains no CPU devices");
}
TF_RET_CHECK(!host_devices.empty());
TF_RETURN_IF_ERROR(AddSynchronizationNode(
configuration_node_def, host_devices.front()->name(),
global_array_id_nodes, wait_node, output_dependencies, graph));
return Status::OK();
}));
if (VLOG_IS_ON(1)) {
DumpGraphToFile("distributed_tpu_configuration_after", *graph,
options.flib_def);
}
VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
return Status::OK();
}
Status DistributedTPUShutdownRewritePass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "DistributedTPUShutdownRewritePass::Run";
Graph* graph = options.graph->get();
if (VLOG_IS_ON(1)) {
DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
options.flib_def);
}
// This pass can only run in the session master, which should fill
// in the device_set field of the options.
TF_RET_CHECK(options.device_set != nullptr);
TF_RETURN_IF_ERROR(
DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
kShutdownOp, graph, *options.device_set,
[](const NodeDef& shutdown_node_def,
const string& shutdown_device_name,
const std::vector<Device*>& host_devices,
const std::vector<Node*>& input_dependencies,
const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
output_dependencies,
Graph* graph) -> Status {
Node* shutdown_node;
TF_RETURN_IF_ERROR(
AddShutdownNode(shutdown_node_def, shutdown_device_name,
output_dependencies, graph, &shutdown_node));
// Add the host disconnect nodes.
for (const auto host_device : host_devices) {
TF_RETURN_IF_ERROR(
AddHostDisconnectNode(host_device->name(), input_dependencies,
shutdown_node, -1, graph));
}
return Status::OK();
}));
if (VLOG_IS_ON(1)) {
DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
}
VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
return Status::OK();
}
} // namespace tensorflow

View File: tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h

@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Rewrites ConfigureDistributedTPU Op into a graph that configures each host.
//
// See the comment at the top of
// third_party/tensorflow/core/ops/tpu_configuration_ops.cc to see the
// sequence of Ops used to configure a distributed TPU system.
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM
// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU
// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU
// device of each host in the same job as the given TPU_SYSTEM device.
class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM
// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and
// _DisconnectHostFromDistributedTPUSystem on the CPU device of each
// host in the same job as the given TPU_SYSTEM device.
class DistributedTPUShutdownRewritePass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
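
Both passes are normally driven by the optimization pass registry during session setup (see the registration file later in this commit), but the interface is just GraphOptimizationPass::Run. A rough sketch of invoking the configuration pass directly, assuming a graph that already contains the dummy ConfigureDistributedTPU node and a populated DeviceSet; the surrounding setup is elided, so this is illustrative rather than a complete program:

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

namespace tensorflow {

Status RewriteConfigurationOps(std::unique_ptr<Graph>* graph,
                               FunctionLibraryDefinition* flib_def,
                               const DeviceSet* device_set) {
  GraphOptimizationPassOptions options;
  options.graph = graph;
  options.flib_def = flib_def;
  // The pass TF_RET_CHECKs that device_set is non-null; it is normally
  // filled in by the session master.
  options.device_set = device_set;
  DistributedTPUConfigurationRewritePass pass;
  return pass.Run(options);
}

}  // namespace tensorflow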

View File: tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc

@@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
namespace tensorflow {
namespace {
// This pass removes the TPUEmbeddingConfiguration in ConfigureDistributedTPU.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
DistributedTPUConfigurationRewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
DistributedTPUShutdownRewritePass);
} // namespace
} // namespace tensorflow
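
The registrations above use the standard optimization-registry macro at phase PRE_PLACEMENT with priority 20, so both passes run before placement alongside other pre-placement rewrites. The same idiom for a hypothetical pass of one's own (MyTpuCleanupPass is illustrative only) looks like this:

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {
namespace {

// Hypothetical pass used only to illustrate the registration idiom.
class MyTpuCleanupPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    // A real pass would inspect and mutate *options.graph here.
    return Status::OK();
  }
};

// Runs after placement, at priority 30 within that phase.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 30,
                      MyTpuCleanupPass);

}  // namespace
}  // namespace tensorflow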

View File: tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc

@@ -0,0 +1,255 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Helper functions for TPU rewrite passes.
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// LINT.IfChange
Status DistributedTPURewriteHelpers::GetSystemDevice(
const string& system_spec_string, const DeviceSet& device_set,
DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
system_spec->Clear();
}
// Callers may have relied on an Op only being registered on TPU_SYSTEM
// devices to ensure the Op is placed there. Augment the device spec to make
// the device type explicit.
if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
system_spec->type = DEVICE_TPU_SYSTEM;
system_spec->has_type = true;
system_spec->id = 0;
system_spec->has_id = true;
}
std::vector<Device*> system_devices;
device_set.FindMatchingDevices(*system_spec, &system_devices);
if (system_devices.empty()) {
if (system_spec_string.empty()) {
return errors::InvalidArgument(
"No TPU_SYSTEM device found. Please ensure that you're connected to "
"a host with a TPU_SYSTEM device.");
}
return errors::InvalidArgument("No matching devices found for '",
system_spec_string, "'");
} else if (system_devices.size() > 1) {
// Validate that all system devices are part of the same job.
std::unordered_set<string> job_names;
for (auto device : system_devices) {
const auto& parsed_name = device->parsed_name();
TF_RET_CHECK(parsed_name.has_job);
job_names.insert(parsed_name.job);
}
if (job_names.size() > 1) {
return errors::InvalidArgument(
"System devices cannot be part "
"of multiple different jobs. Found: ",
str_util::Join(job_names, ","));
}
// Identify the lexicographically first device from the list of
// valid TPU SYSTEM devices, so that every process in the same
// 'cluster' definition uses the same system device.
std::sort(system_devices.begin(), system_devices.end(),
[](Device* i, Device* j) {
auto i_name = i->parsed_name();
auto j_name = j->parsed_name();
if (i_name.replica != j_name.replica) {
return i_name.replica < j_name.replica;
}
return i_name.task < j_name.task;
});
}
*system_device = system_devices[0];
if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
return errors::InvalidArgument("Unable to re-parse system device name ",
(*system_device)->name(),
" as a device spec.");
}
return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
// LINT.IfChange
Status DistributedTPURewriteHelpers::GetHostSystemDevices(
const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
std::vector<Device*>* host_system_devices) {
DeviceNameUtils::ParsedName host_spec;
if (system_spec.has_job) {
// The system Op has been explicitly assigned to a job, so we want
// all the hosts in that job.
CHECK(DeviceNameUtils::ParseFullName(
strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
":0"),
&host_spec));
} else {
// The system Op has not been explicitly assigned to a
// job, so take all hosts in the system. There will be a runtime
// error if some of those hosts don't contain TPU devices.
CHECK(DeviceNameUtils::ParseFullName(
strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
}
device_set.FindMatchingDevices(host_spec, host_system_devices);
TF_RET_CHECK(!host_system_devices->empty())
<< "No hosts found matching device spec "
<< DeviceNameUtils::ParsedNameToString(host_spec);
// Check that all the devices belong to the same job.
TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
const string& job_name = (*host_system_devices)[0]->parsed_name().job;
int replica = (*host_system_devices)[0]->parsed_name().replica;
for (const auto host_device : *host_system_devices) {
const auto& parsed_name = host_device->parsed_name();
TF_RET_CHECK(parsed_name.has_job);
if (parsed_name.job != job_name) {
return errors::InvalidArgument(
"All TPU host devices must be in the same job");
}
TF_RET_CHECK(parsed_name.has_replica);
if (parsed_name.replica != replica) {
return errors::InvalidArgument(
"All TPU host devices must be in the same replica");
}
}
// Sort the devices by replica and then task.
std::sort(host_system_devices->begin(), host_system_devices->end(),
[](Device* i, Device* j) {
auto i_name = i->parsed_name();
auto j_name = j->parsed_name();
return i_name.task < j_name.task;
});
return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
// LINT.IfChange
Status DistributedTPURewriteHelpers::GetTPUDevices(
const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
// GetHostSystemDevices returns the CPU device on each host that is
// going to be used for executing TPU code.
std::vector<Device*> host_system_devices;
TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
system_spec, device_set, &host_system_devices));
// Enumerate all the physical devices. Enumerate devices on task 0,
// then task 1, etc.
std::sort(host_system_devices.begin(), host_system_devices.end(),
[](Device* i, Device* j) {
return i->parsed_name().task < j->parsed_name().task;
});
*num_tpus_per_host = 0;
tpu_devices->clear();
tpu_devices->reserve(host_system_devices.size());
for (const auto device : host_system_devices) {
// Make a copy of the parsed name because we are going to change it.
DeviceNameUtils::ParsedName device_spec = device->parsed_name();
device_spec.has_type = true;
device_spec.type = "TPU";
// Enumerate all the available TPUs.
device_spec.has_id = false;
std::vector<Device*> host_tpu_devices;
device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
// Sort the devices by device id.
std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
[](Device* i, Device* j) {
return i->parsed_name().id < j->parsed_name().id;
});
if (tpu_devices->empty()) {
// First iteration: set *num_tpus_per_host to the number of TPUs on the
// first host.
*num_tpus_per_host = host_tpu_devices.size();
} else if (*num_tpus_per_host != host_tpu_devices.size()) {
// Subsequent iterations: check the number of TPUs match the number on
// the first host.
return errors::InvalidArgument(
"Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
" vs. ", host_tpu_devices.size());
}
tpu_devices->push_back(std::move(host_tpu_devices));
}
return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)
Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
const string& node_type, Graph* graph, const DeviceSet& device_set,
const std::function<
Status(const NodeDef& configuration_node_def,
const string& configuration_device_name,
const std::vector<Device*>& host_devices,
const std::vector<Node*>& input_dependencies,
const std::vector<OutputDependency>& output_dependencies,
Graph* graph)>& action) {
// Find all the matching nodes before mutating the graph.
std::vector<Node*> nodes;
for (Node* node : graph->nodes()) {
if (node->type_string() == node_type) {
nodes.push_back(node);
}
}
for (Node* node : nodes) {
string spec_string = node->requested_device();
DeviceNameUtils::ParsedName spec;
Device* device;
TF_RETURN_IF_ERROR(
GetSystemDevice(spec_string, device_set, &spec, &device));
const string& device_name = device->name();
std::vector<Device*> host_devices;
TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));
std::vector<Node*> input_dependencies;
for (const Edge* edge : node->in_edges()) {
// Config ops have no inputs, so all edges must be control edges.
CHECK(edge->IsControlEdge());
input_dependencies.push_back(edge->src());
}
std::vector<OutputDependency> output_dependencies;
for (const Edge* edge : node->out_edges()) {
OutputDependency dep;
dep.src_output = edge->src_output();
dep.dst = edge->dst();
dep.dst_input = edge->dst_input();
output_dependencies.push_back(dep);
}
NodeDef node_def = node->def();
// Remove the node now so we can insert a new node with the same
// name inside the action.
graph->RemoveNode(node);
TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
input_dependencies, output_dependencies, graph));
}
return Status::OK();
}
} // namespace tensorflow

View File: tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h

@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Helper functions for TPU rewrite passes.
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class DistributedTPURewriteHelpers {
public:
// Given a user-assigned device string, system_spec_string, parse it into
// system_spec. Verify that the device type is either TPU_SYSTEM or
// unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the
// type, verify that the spec matches a unique device in device_set, and
// return that device in system_device. The normal use case is for
// system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the
// job that contains the TPU hardware.
// TODO(b/110910013): Possibly remove the tpu system device.
static Status GetSystemDevice(const string& system_spec_string,
const DeviceSet& device_set,
DeviceNameUtils::ParsedName* system_spec,
Device** system_device);
// Given a parsed system spec (e.g., the one returned above from
// GetSystemDevice), return in host_devices the TPU_SYSTEM:0 device on
// every host in the spec's job. If the spec does not include an explicit job,
// "localhost" is used. Returns an error if system_spec matches devices from
// multiple jobs or replicas.
static Status GetHostSystemDevices(
const DeviceNameUtils::ParsedName& system_spec,
const DeviceSet& device_set, std::vector<Device*>* host_system_devices);
// Given a parsed system spec (e.g., the one returned above from
// GetSystemDevice), sets `*tpu_devices` to a per-host vector of the TPU
// devices on every host in the spec's job. If the spec does not include an
// explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number
// of TPU devices in each host, and verifies that each host in the job has
// the same number of TPU devices.
// Returns an error if system_spec matches devices from multiple jobs or
// replicas.
static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec,
const DeviceSet& device_set,
int* num_tpus_per_host,
std::vector<std::vector<Device*>>* tpu_devices);
// Perform 'action' on every node in 'graph' of type
// 'node_type'. This function is designed for use with configuration
// Ops that have no inputs or outputs. The arguments passed to 'action' are:
// 'configuration_node_def': the NodeDef of the node that matched
// 'configuration_device_name': the name of the device that the
// matching node is placed on
// 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are
// in the same system as the node that matched.
// 'input_dependencies': the set of nodes that have control edges to
// the matching node.
// 'output_dependencies': the set of output port, destination node, input port
// triples that have edges from the matching node. Input port is
// Graph::kControlSlot for a control edge.
// 'graph': the graph being mutated.
struct OutputDependency {
int src_output;
Node* dst;
int dst_input;
};
static Status ForConfigurationNodeMatchingType(
const string& node_type, Graph* graph, const DeviceSet& device_set,
const std::function<
Status(const NodeDef& configuration_node_def,
const string& configuration_device_name,
const std::vector<Device*>& host_devices,
const std::vector<Node*>& input_dependencies,
const std::vector<OutputDependency>& output_dependencies,
Graph* graph)>& action);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
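
A condensed sketch of how ForConfigurationNodeMatchingType is meant to be driven, mirroring the configuration pass earlier in this commit but with the callback body reduced to a no-op; the op type string "MyConfigureOp" is illustrative, whereas the real callers pass ConfigureDistributedTPU or ShutdownDistributedTPU:

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"

namespace tensorflow {

Status RewriteMyConfigureOps(Graph* graph, const DeviceSet& device_set) {
  return DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
      "MyConfigureOp", graph, device_set,
      [](const NodeDef& configuration_node_def,
         const string& configuration_device_name,
         const std::vector<Device*>& host_devices,
         const std::vector<Node*>& input_dependencies,
         const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
             output_dependencies,
         Graph* graph) -> Status {
        // The matched node has already been removed; a real callback rebuilds
        // a replacement subgraph here and rewires the recorded dependencies.
        return Status::OK();
      });
}

}  // namespace tensorflow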

View File: tensorflow/core/tpu/kernels/BUILD

@@ -0,0 +1,288 @@
# TPU Kernel Implementations
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library_cc",
)
package(
default_visibility = [
"//tensorflow/core/tpu:__subpackages__",
"//tensorflow/stream_executor/tpu:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tpu_compile_op_options",
srcs = ["tpu_compile_op_options.cc"],
hdrs = ["tpu_compile_op_options.h"],
)
cc_library(
name = "tpu_configuration_ops",
srcs = ["tpu_configuration_ops.cc"],
hdrs = ["tpu_configuration_ops.h"],
deps = [
":tpu_mesh_state_interface",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/tpu:tpu_config_c_api",
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu:tpu_library_loader",
"//tensorflow/stream_executor/tpu:proto_helper",
],
alwayslink = 1,
)
cc_library(
name = "tpu_compile_c_api_hdrs",
hdrs = ["tpu_compile_c_api.h"],
deps = [
":tpu_mesh_state_c_api",
"//tensorflow/c:tf_datatype",
"//tensorflow/stream_executor/tpu:proto_helper",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
],
)
tf_proto_library_cc(
name = "tpu_executable_info_proto",
srcs = ["tpu_executable_info.proto"],
cc_api_version = 2,
protodeps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:protos_all",
],
)
tf_proto_library_cc(
name = "tpu_compile_proto",
srcs = ["tpu_compile.proto"],
cc_api_version = 2,
protodeps = [
":tpu_executable_info_proto",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:protos_all",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto",
],
)
cc_library(
name = "tpu_compilation_cache_key",
srcs = [],
hdrs = [
"tpu_compilation_cache_key.h",
],
deps = ["@com_google_absl//absl/types:optional"],
)
cc_library(
name = "tpu_compile_op_support",
srcs = ["tpu_compile_op_support.cc"],
hdrs = ["tpu_compile_op_support.h"],
deps = [
":tpu_compilation_cache_key",
":tpu_compile_c_api_hdrs",
":tpu_compile_proto_cc",
":tpu_executable_info_proto_cc",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_module_group",
"//tensorflow/core:framework",
"//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/stream_executor/tpu:proto_helper",
"//tensorflow/stream_executor/tpu:status_helper",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tpu_compilation_cache_entry",
hdrs = [
"tpu_compilation_cache_entry.h",
],
deps = [
":tpu_executable_info_proto_cc",
":tpu_program",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core/lib/core:refcount",
],
)
cc_library(
name = "tpu_compilation_cache_lookup",
srcs = ["tpu_compilation_cache_lookup.cc"],
hdrs = [
"tpu_compilation_cache_lookup.h",
],
deps = [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_external",
":tpu_compilation_cache_proto_cc",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/platform:status",
"//tensorflow/core/profiler/lib:traceme",
],
)
cc_library(
name = "tpu_mesh_state_c_api",
hdrs = ["tpu_mesh_state_c_api.h"],
)
cc_library(
name = "tpu_mesh_state_interface",
srcs = [],
hdrs = ["tpu_mesh_state_interface.h"],
deps = [
":tpu_compile_c_api_hdrs",
":tpu_mesh_state_c_api",
"//tensorflow/compiler/xla/service",
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/tpu:tpu_config_c_api",
],
)
cc_library(
name = "tpu_program",
srcs = ["tpu_program.cc"],
hdrs = ["tpu_program.h"],
deps = [
":tpu_compile_c_api_hdrs",
":tpu_compile_op_support",
":tpu_compile_proto_cc",
":tpu_executable_info_proto_cc",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:hlo_module_group",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/stream_executor/tpu:proto_helper",
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "tpu_compilation_cache_external",
srcs = ["tpu_compilation_cache_external.cc"],
hdrs = [
"tpu_compilation_cache_external.h",
],
deps = [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_key",
":tpu_compilation_cache_metrics", # buildcleaner: keep
":tpu_compilation_cache_metrics_hdrs",
":tpu_compilation_cache_proto_cc",
":tpu_compile_c_api_hdrs",
":tpu_compile_op_support",
":tpu_mesh_state_interface",
":tpu_program",
":tpu_util",
":trace_util_hdrs",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "tpu_compilation_cache_metrics_hdrs",
hdrs = ["tpu_compilation_cache_metrics.h"],
deps = [
"//tensorflow/core/platform:types",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tpu_compilation_cache_metrics",
srcs = ["tpu_compilation_cache_metrics.cc"],
deps = [
":tpu_compilation_cache_metrics_hdrs",
],
)
cc_library(
name = "trace_util_hdrs",
srcs = [],
hdrs = ["trace_util.h"],
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tpu_util_hdrs",
srcs = [],
hdrs = ["tpu_util.h"],
deps = [
":tpu_compilation_cache_key",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tpu_util",
srcs = ["tpu_util.cc"],
hdrs = ["tpu_util.h"],
deps = [
":tpu_compilation_cache_key",
"//tensorflow/cc:ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
tf_proto_library_cc(
name = "tpu_compilation_cache_proto",
srcs = ["tpu_compilation_cache.proto"],
cc_api_version = 2,
)

View File: tensorflow/core/tpu/kernels/tpu_compilation_cache.proto

@@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow.tpu;
// Target type for compilation cache fetch operation.
enum CompilationCacheFetchTarget {
INVALID = 0;
MAIN = 1;
SHARDING = 2;
UNSHARDING = 3;
}
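
With cc_api_version = 2 in the tf_proto_library_cc rule in the kernels BUILD file, this generates an ordinary C++ enum. A small sketch of consuming it, assuming the proto lives at tensorflow/core/tpu/kernels/ as the BUILD rule suggests (the helper below is illustrative, not part of this change):

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"

namespace tensorflow {
namespace tpu {

// Returns true for the fetch targets that refer to resharding programs.
bool IsShardingTarget(CompilationCacheFetchTarget target) {
  return target == CompilationCacheFetchTarget::SHARDING ||
         target == CompilationCacheFetchTarget::UNSHARDING;
}

}  // namespace tpu
}  // namespace tensorflow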

View File: tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h

@@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"
namespace tensorflow {
namespace tpu {
class CompilationCacheEntry {
public:
explicit CompilationCacheEntry(
std::unique_ptr<const TpuProgram> tpu_program)
: tpu_program_(std::move(tpu_program)) {}
// Constructor for an empty entry.
CompilationCacheEntry()
: tpu_program_(nullptr) {}
const TPUExecutableInfoProto* get_executable_info() const {
return &tpu_program_->executable_info();
}
const TPUHostTransferInfoProto* get_host_transfer_info() const {
return &tpu_program_->host_transfer_info();
}
const xla::HloProto* get_hlo_metadata() const {
return &tpu_program_->hlo_metadata();
}
// TODO(henrytan,jiawenhao): When should we expect more than one
// XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
const XLA_TpuProgram* get_tpu_program() const {
CHECK_EQ(tpu_program_->program_count(), 1);
return tpu_program_->tpu_programs()[0];
}
private:
std::unique_ptr<const TpuProgram> tpu_program_;
};
// Base class for a reference to a cached proto. A unique_ptr to a
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
// and ensures the underlying proto is not garbage-collected until the client
// discards the ptr.
class CompilationCacheEntryRef {
public:
virtual ~CompilationCacheEntryRef() = default;
// Returns a CompilationCacheEntry that should not be used beyond the lifetime
// of the CompilationCacheEntryRef.
virtual CompilationCacheEntry get() = 0;
};
// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete
// ref holder object.
class CompilationRefHolder : public ResourceBase {
public:
~CompilationRefHolder() override = default;
};
} // namespace tpu
} // namespace tensorflow
#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_
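
A short sketch of the intended access pattern for these types: a cache lookup (not shown here) hands back a CompilationCacheEntryRef, and the entry obtained from it is only valid while that ref is held:

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"

namespace tensorflow {
namespace tpu {

// Reads the executable metadata out of a looked-up cache entry. The returned
// pointer is only valid for as long as `ref` (and thus the underlying cache
// entry) stays alive.
const TPUExecutableInfoProto* GetExecutableInfo(CompilationCacheEntryRef* ref) {
  CompilationCacheEntry entry = ref->get();
  return entry.get_executable_info();
}

}  // namespace tpu
}  // namespace tensorflow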

View File: tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc

@@ -0,0 +1,791 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include <string>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"
namespace tensorflow {
namespace tpu {
namespace {
using CompilationEntry = TpuCompilationCacheInterface::CompilationEntry;
int64 get_uid() {
uint64 unsigned_rand = random::New64() & INT64_MAX;
return static_cast<int64>(unsigned_rand);
}
void PopulateEntry(const std::string& key, CompilationEntry* entry,
std::unique_ptr<TpuProgram> tpu_program) {
// Make the unique keys for each cached proto.
for (int i = 0; i < tpu_program->program_count(); ++i) {
entry->proto_key.push_back(ProtoKeyForComputation(key, i));
}
entry->tpu_program = std::move(tpu_program);
entry->initialized = true;
}
std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
if (!key.has_guaranteed_const) {
return key.prefix;
}
return absl::StrCat(key.prefix, "|", key.session_handle, "|",
key.guaranteed_const_fingerprint());
}
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
// data to compute the fingerprint.
std::string GuaranteedConstFingerprint(
const string& fingerprint_in_metadata,
const OpInputList& guaranteed_constants) {
if (fingerprint_in_metadata.empty()) {
uint64_t fingerprint = 0;
for (const auto& constant : guaranteed_constants) {
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
fingerprint, constant.tensor_data().data(),
constant.tensor_data().size());
}
return std::to_string(fingerprint);
} else {
return fingerprint_in_metadata;
}
}
std::string CreateShapePrefix(
const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
std::string shapes_prefix;
for (const TensorShape& shape : dynamic_shapes) {
for (int64 size : shape.dim_sizes()) {
absl::StrAppend(&shapes_prefix, size, ",");
}
absl::StrAppend(&shapes_prefix, ";");
}
return shapes_prefix;
}
// Include compilation configurations of the arguments that are not captured
// by the called graph.
std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
std::string config_prefix;
for (const auto& arg : metadata.args()) {
if (arg.is_same_data_across_replicas()) {
absl::StrAppend(&config_prefix, ":s");
// Same.
} else {
// Different.
absl::StrAppend(&config_prefix, ":");
}
if (arg.enable_xla_sharding() ==
tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
// Enabled.
absl::StrAppend(&config_prefix, "e");
}
if (arg.unrestricted_layout()) {
// Unrestricted.
absl::StrAppend(&config_prefix, ":u");
}
absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
if (arg.has_shape()) {
absl::StrAppend(&config_prefix, ",shape(");
for (const auto& dim : arg.shape().dim()) {
absl::StrAppend(&config_prefix, dim.size(), ",");
}
absl::StrAppend(&config_prefix, ")");
}
}
return config_prefix;
}
} // namespace
TpuCompilationCacheInterface::TpuCompilationCacheInterface(
int64_t max_cache_size)
: max_cache_size_(max_cache_size) {
if (max_cache_size < 0) {
LOG(FATAL) << "`max_cache_size` value must be greater than or equal to 0";
}
VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
}
TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
// A buggy client may be holding onto a reference, or a client might have
// crashed while holding onto a reference. In either case, discard all
// outstanding client references to avoid leaking storage.
for (const auto& entry : entries_by_uid_) {
while (entry.second->external_references > 0) {
TF_CHECK_OK(Release(entry.first));
}
}
while (!entries_by_last_use_.empty()) {
UnloadAndDestroy(MarkOldestEntryForEviction());
}
// By the time the cache is deleted all reference holders should have already
// been deleted, since they were holding references to the cache. So all
// entries should be gone at this point.
CHECK_EQ(cache_store_.size(), 0);
CHECK_EQ(entries_by_uid_.size(), 0);
CHECK_EQ(entries_by_proto_key_.size(), 0);
CHECK_EQ(cache_size_, 0);
CHECK_EQ(marked_for_eviction_size_, 0);
}
std::string TpuCompilationCacheInterface::FindCacheKey(
const TpuCompilationCacheKey& subgraph_key) const {
if (!subgraph_key.has_guaranteed_const) {
return subgraph_key.prefix;
}
auto iter = session_key_map_.find(
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
if (iter != session_key_map_.end()) {
return iter->second;
}
iter = fingerprint_key_map_.find(strings::StrCat(
subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
if (iter != fingerprint_key_map_.end()) {
return iter->second;
}
VLOG(1) << "No matching cache key found for key "
<< ConstructCompilationCacheKey(subgraph_key);
return "";
}
void TpuCompilationCacheInterface::InsertEntry(
const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
CompilationEntry* entry) {
entry->parent = this;
entry->subgraph_key = cache_key;
entry->uid = get_uid();
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
entry->cache_entry_debug_string = subgraph_key.prefix;
VLOG(1) << "Cache Initializing Entry Session Debug "
<< entry->cache_entry_debug_string;
if (!subgraph_key.has_guaranteed_const) {
return;
}
session_key_map_.insert(std::make_pair(
strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
cache_key));
fingerprint_key_map_.insert(std::make_pair(
strings::StrCat(subgraph_key.prefix,
subgraph_key.guaranteed_const_fingerprint()),
cache_key));
}
CompilationEntry* TpuCompilationCacheInterface::InitializeEntry(
const string& key,
const std::function<Status(TpuProgram*)>& initialize_program,
const TpuCompilationCacheKey& subgraph_key) {
CompilationEntry* main_entry = new CompilationEntry();
// Add the entry to the cache, with size zero since there are no compiled
// programs in it. Once the subgraph has been compiled,
// UpdateEntryAfterCompilation will be called to potentially mark old entries
// that don't fit any more for eviction.
//
// At this point there is one reference to entry, which is owned by the caller
// who created the entry. A second reference, owned by the cache, will be
// added below since we leave the entry in the 'marked for eviction' state
// here.
InsertEntry(key, subgraph_key, main_entry);
// Initialize the programs outside the lock so that other cache operations
// can proceed during the (potentially lengthy) initialization.
Status initialization_status;
auto tpu_program = absl::make_unique<TpuProgram>();
{
mu_.Unlock();
{
profiler::TraceMe compile_programs_traceme(
"TPU compilation cache compile",
/*level=*/2);
initialization_status = initialize_program(tpu_program.get());
}
mu_.Lock();
}
main_entry->initialization_status = initialization_status;
// Add the entry to the uid index.
auto uid_inserted = entries_by_uid_.insert(
std::pair<int64, CompilationEntry*>(main_entry->uid, main_entry));
CHECK(uid_inserted.second);
if (initialization_status.ok()) {
// Compute the entries total size once all members are initialized.
main_entry->total_size = tpu_program->program_size();
}
// TODO(henrytan): handle sharding/unsharding.
PopulateEntry(key, main_entry, std::move(tpu_program));
for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
auto entry_inserted = entries_by_proto_key_.insert(
std::pair<string, std::pair<CompilationEntry*, int>>(
main_entry->proto_key[i], std::make_pair(main_entry, i)));
CHECK(entry_inserted.second);
}
// Add the size to marked_for_eviction_size_ since it will be adjusted down
// again when the newly-created entry gets unmarked.
marked_for_eviction_size_ += main_entry->total_size;
return main_entry;
}
/*static*/ TpuCompilationCacheKey
TpuCompilationCacheInterface::CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint,
absl::string_view mlir_module,
const tensorflow::OpInputList& guaranteed_constants,
const std::vector<tensorflow::TensorShape>& dynamic_shapes,
const tensorflow::tpu::TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state) {
VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
VLOG(1) << "shapes_prefix = " << shapes_prefix;
std::string config_prefix = CreateConfigPrefix(metadata);
VLOG(1) << "config_prefix = " << config_prefix;
std::vector<int32_t> flattened_device_ids;
if (metadata.has_device_assignment()) {
for (const auto& device :
metadata.device_assignment().computation_devices()) {
flattened_device_ids.insert(flattened_device_ids.end(),
device.replica_device_ids().begin(),
device.replica_device_ids().end());
}
}
// TODO(henrytan): return the debug_string.
const char* prefix =
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty{
config_prefix.data(),
shapes_prefix.data(),
function_name.data(),
mlir_module.data(),
flattened_device_ids.data(),
flattened_device_ids.size(),
guaranteed_constants.size(),
function_library_fingerprint,
metadata.num_cores_per_replica(),
metadata.num_replicas(),
mesh_state.data(),
});
auto buffer_cleanup = gtl::MakeCleanup([prefix]() { delete[] prefix; });
TpuCompilationCacheKey key;
key.prefix = prefix;
// Guaranteed constants can be different across sessions. Use session_handle
// and guaranteed_const fingerprint to guarantee no collision.
if (guaranteed_constants.size() > 0) {
key.has_guaranteed_const = true;
key.session_handle = metadata.session_handle();
// Both `metadata` and `guaranteed_constants` are captured by reference,
// on the assumption that their lifetimes are managed by the
// `TPUCompileOpKernelImpl`, which outlives the compilation cache lookups.
string fingerprint;
key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
fingerprint]() mutable {
if (fingerprint.empty()) {
fingerprint = GuaranteedConstFingerprint(
metadata.guaranteed_const_fingerprint(), guaranteed_constants);
}
return fingerprint;
};
}
return key;
}
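// An aside on the lambda above (not part of this file's API): the mutable
// value capture of `fingerprint` memoizes the result, so the potentially
// expensive fingerprint over the guaranteed constants is computed at most
// once, and only if a caller actually invokes guaranteed_const_fingerprint().
// The same idiom in isolation, with an illustrative std::hash standing in for
// the real fingerprint function:
//
//   std::function<std::string()> MakeLazyFingerprint(std::string data) {
//     std::string cached;
//     return [data = std::move(data), cached]() mutable {
//       if (cached.empty()) {
//         cached = std::to_string(std::hash<std::string>{}(data));
//       }
//       return cached;
//     };
//   }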
TpuCompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
return new RefHolder(this);
}
Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) {
profiler::TraceMe key_release_traceme(
"TPU compilation cache possibly evict uid",
/*level=*/2);
CompilationEntry* deleted_entry = nullptr;
{
absl::MutexLock lock(&mu_);
auto iter = entries_by_uid_.find(subgraph_uid);
if (iter == entries_by_uid_.end()) {
// If already evicted, return ok.
return Status::OK();
}
// Mark entry for eviction.
CompilationEntry* subgraph_to_evict = iter->second;
// If there are external references, this API should not be used.
if (subgraph_to_evict->external_references != 0) {
return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
" external_references greater than zero. Should "
"use TpuCompilationCache::Release.");
}
VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction";
entries_by_last_use_.erase(subgraph_to_evict->last_use);
cache_size_ -= subgraph_to_evict->total_size;
marked_for_eviction_size_ += subgraph_to_evict->total_size;
// Evict if refcount exactly one, otherwise only discard cache's reference
// to the entry while the actual eviction will happen when refholder's
// references go away.
deleted_entry = DiscardEntryRef(subgraph_to_evict);
VLOG(1) << "After possibly evicting entry " << subgraph_uid
<< " refs cache is " << cache_store_.size() << " entries ("
<< cache_size_ + marked_for_eviction_size_
<< " bytes), marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
}
// Unload from device cache if entry is evicted from host cache.
UnloadAndDestroy(deleted_entry);
return Status::OK();
}
Status TpuCompilationCacheInterface::Release(int64 subgraph_uid) {
profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
/*level=*/2);
CompilationEntry* deleted_entry = nullptr;
{
absl::MutexLock lock(&mu_);
auto iter = entries_by_uid_.find(subgraph_uid);
if (iter == entries_by_uid_.end()) {
return errors::NotFound("No cache entry found for uid ", subgraph_uid);
}
CHECK_GT(iter->second->external_references, 0);
--iter->second->external_references;
deleted_entry = DiscardEntryRef(iter->second);
VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
<< cache_store_.size() << " entries ("
<< cache_size_ + marked_for_eviction_size_
<< " bytes), marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
}
UnloadAndDestroy(deleted_entry);
return Status::OK();
}
void TpuCompilationCacheInterface::UnloadAndDestroy(CompilationEntry* entry) {
if (!entry) return;
CHECK(entry->RefCountIsOne());
entry->tpu_program->UnloadAndDestroyPrograms();
entry->Unref();
}
size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
auto erased = cache_store_.erase(key);
TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
auto parsed_key_or_status = ParseCompilationCacheKey(key);
CHECK(parsed_key_or_status.status().ok());
const TpuCompilationCacheKey parsed_key =
parsed_key_or_status.ConsumeValueOrDie();
if (!parsed_key.has_guaranteed_const) {
return erased;
}
session_key_map_.erase(
strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
fingerprint_key_map_.erase(strings::StrCat(
parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
return erased;
}
ABSL_MUST_USE_RESULT CompilationEntry*
TpuCompilationCacheInterface::DiscardEntryRef(CompilationEntry* entry) {
if (entry->RefCountIsOne()) {
// The last reference to this entry is going away, so really delete it from
// the cache in such a way that it can't be restored by being looked up
// again.
// Sanity-check that it has been marked for eviction.
CHECK(entries_by_last_use_.find(entry->last_use) ==
entries_by_last_use_.end());
// Update the counter tracking how much space is taken up by entries that
// are marked for eviction.
marked_for_eviction_size_ -= entry->total_size;
// Remove the entry from the cache.
auto erased = RemoveEntry(entry->subgraph_key);
if (erased == 0) {
LOG(FATAL) << "Tried to discard nonexistent cache entry";
}
erased = entries_by_uid_.erase(entry->uid);
CHECK_EQ(erased, 1);
for (const string& key : entry->proto_key) {
erased = entries_by_proto_key_.erase(key);
CHECK_EQ(erased, 1);
}
// The actual deletion will happen outside the lock in UnloadAndDestroy().
return entry;
}
entry->Unref();
return nullptr;
}
void TpuCompilationCacheInterface::DiscardEntryRefs(
gtl::ArraySlice<CompilationEntry*> entries) {
std::vector<CompilationEntry*> removed_entries;
{
absl::MutexLock lock(&mu_);
for (auto entry : entries) {
removed_entries.push_back(DiscardEntryRef(entry));
}
VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
<< " entries (" << cache_size_ + marked_for_eviction_size_
<< " bytes), marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
}
for (auto removed_entry : removed_entries) {
UnloadAndDestroy(removed_entry);
}
}
ABSL_MUST_USE_RESULT CompilationEntry*
TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
CompilationEntry* entry_to_mark = entries_by_last_use_.begin()->second;
VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
entries_by_last_use_.erase(entry_to_mark->last_use);
cache_size_ -= entry_to_mark->total_size;
marked_for_eviction_size_ += entry_to_mark->total_size;
// Discard the cache's reference to entry. If steps are holding onto
// references to entry it won't be deleted until the last step holding it
// completes. It stays in the cache in the meantime and can be resurrected
// by a call to CompileIfKeyAbsent if that occurs before the last reference
// expires.
return DiscardEntryRef(entry_to_mark);
}
void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries) {
// The entry was previously marked for eviction (or is newly created) so
// unmark it. Add a reference (owned by the cache), update the cache size, and
// mark something old for eviction if necessary.
entry->Ref();
marked_for_eviction_size_ -= entry->total_size;
cache_size_ += entry->total_size;
  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size()
  // == 1, which means there's only one entry not already marked for eviction),
  // so that an entry persists in the cache even if it is larger than the
  // allocated cache size.
while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
if (auto entry_to_evict = MarkOldestEntryForEviction()) {
removed_entries->push_back(entry_to_evict);
}
}
}
Status TpuCompilationCacheInterface::ToSubEntryRef(
CompilationCacheEntryRef* entry,
CompilationCacheFetchTarget fetch_target) const {
return static_cast<EntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
}
TpuCompilationCacheInterface::EntryRefImpl::EntryRefImpl(
TpuCompilationCacheInterface* parent, CompilationEntry* entry, int index)
: parent_(parent), entry_(entry), index_(index) {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
entry_->Ref();
} else {
// This is a sharding/unsharding entry nested in a main entry. Only refcount
// the main entry.
entry_->main_entry->Ref();
}
}
TpuCompilationCacheInterface::EntryRefImpl::~EntryRefImpl() {
if (entry_ == nullptr) {
return;
}
if (entry_->main_entry == nullptr) {
parent_->DiscardEntryRefs({entry_});
} else {
parent_->DiscardEntryRefs({entry_->main_entry});
}
}
CompilationCacheEntry TpuCompilationCacheInterface::EntryRefImpl::get() {
if (entry_ == nullptr) {
// Create an empty entry if the entry is nullptr. This corresponds to
// non-existing sharding/unsharding entries.
return CompilationCacheEntry();
}
return CompilationCacheEntry(std::move(entry_->tpu_program));
}
Status TpuCompilationCacheInterface::EntryRefImpl::ToSubEntryRef(
CompilationCacheFetchTarget fetch_target) {
CompilationEntry* target = nullptr;
switch (fetch_target) {
case CompilationCacheFetchTarget::MAIN:
target = entry_;
break;
case CompilationCacheFetchTarget::SHARDING:
target = entry_->sharding_entry.get();
break;
case CompilationCacheFetchTarget::UNSHARDING:
target = entry_->unsharding_entry.get();
break;
default:
return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
}
if (target == nullptr) {
// Cache entry does not have an unsharding subentry. Unref and replace
// with nullptr.
parent_->DiscardEntryRefs({entry_});
}
// Otherwise, since the refcount is always on the main entry, we don't need
// ref/unref.
entry_ = target;
return Status::OK();
}
Status TpuCompilationCacheInterface::Lookup(
int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme(
"TPU compilation cache proto lookup by uid",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_uid_.find(uid);
if (iter == entries_by_uid_.end()) {
return errors::NotFound("No subgraph found for uid ", uid);
}
CompilationEntry* cache_entry = iter->second;
if (proto_index < 0 ||
proto_index >= cache_entry->tpu_program->program_size()) {
return errors::NotFound("No proto found for core index ", proto_index,
" in subgraph with uid ", uid);
}
*entry = std::unique_ptr<CompilationCacheEntryRef>(
new EntryRefImpl(this, cache_entry, proto_index));
return Status::OK();
}
Status TpuCompilationCacheInterface::Lookup(
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry) {
entry->reset();
profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
/*level=*/2);
absl::MutexLock lock(&mu_);
const auto iter = entries_by_proto_key_.find(proto_key);
if (iter == entries_by_proto_key_.end()) {
return errors::NotFound("No proto found for key ", proto_key);
}
CompilationEntry* cache_entry = iter->second.first;
int proto_index = iter->second.second;
*entry = std::unique_ptr<CompilationCacheEntryRef>(
new EntryRefImpl(this, cache_entry, proto_index));
return Status::OK();
}
Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<CompilationEntry*>* removed_entries,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<Status(TpuProgram*)>& compile_function) {
profiler::TraceMe subgraph_lookup_traceme(
"TPU compilation cache subgraph lookup",
/*level=*/2);
  // NOTE: Although a MutexLock is taken here, mu_ is not held continuously for
  // the whole call; InitializeEntry() below releases and reacquires it.
absl::MutexLock lock(&mu_);
std::string cache_key = FindCacheKey(subgraph_key);
auto iter = cache_store_.find(cache_key);
bool is_new_key = iter == cache_store_.end();
const std::string session_name = SessionNameFromMetadata(session_metadata);
CompilationEntry* entry = nullptr;
if (is_new_key) {
cache_key = ConstructCompilationCacheKey(subgraph_key);
TpuCompilationCacheMetrics::IncrementCacheLookupCount(
/*is_cache_hit=*/false, session_name);
const string msg =
strings::StrCat("TPU host compilation cache miss: cache_key(",
cache_key, "), session_name(", session_name, ")");
TRACESTRING(msg);
LOG(INFO) << msg;
// Check if caller has disabled compilation. Set using
// internal::ScopedTpuCompileDisabler.
if (!IsTpuCompilationEnabled()) {
const string error_msg = strings::StrCat(
"[TpuCompilationDisabled]: Compilation cache miss, but compilation "
"disabled, session_name(",
session_name, ") Debug String: ", subgraph_key.debug_string);
if (VLOG_IS_ON(2)) {
VLOG(2) << "Cache Missed. Current cache entries: ";
for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
// TODO(henrytan): add DebugKey as cache_entry_debug_string to
// TpuCompilationCacheKey.
VLOG(2) << "Cache Debug Info: ";
VLOG(2) << it->second->cache_entry_debug_string;
}
}
LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
return errors::NotFound(error_msg);
}
// The single ref on the newly-created entry is owned by the caller.
VLOG(1) << "Before adding new entry for key " << cache_key
<< " with session_name( " << session_name << ");"
<< "; cache is " << cache_store_.size() << " entries ("
<< cache_size_ + marked_for_eviction_size_ << " bytes), "
<< " marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
// Note that InitializeEntry() will Release/Reacquire mu_.
entry = InitializeEntry(cache_key, compile_function, subgraph_key);
TRACELITERAL("TPU host compilation cache: compilation done.");
LOG(INFO) << strings::StrCat(
"TPU host compilation cache: compilation done for cache_key(",
cache_key, "), session_name(", session_name, ")");
// If session_name is present, log some additional stats related to HBM
// here, so that they can be associated directly to the session.
if (!session_name.empty()) {
entry->tpu_program->LogProgramMemorySummary();
}
} else {
TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
const string msg =
strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
"), session_name(", session_name, ")");
TRACESTRING(msg);
VLOG(1) << msg;
VLOG(1) << "Before refreshing entry for key " << cache_key
<< " with session_name( " << session_name << "); cache is "
<< cache_store_.size() << " entries ("
<< cache_size_ + marked_for_eviction_size_ << " bytes), "
<< " marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
entry = iter->second;
// Make a new reference that is owned by the caller.
entry->Ref();
// Block if necessary until the subgraph has been initialized.
mu_.Await(absl::Condition(
+[](CompilationEntry* e) { return e->initialized; }, entry));
}
// Let the caller know the uid of the entry.
*uid = entry->uid;
// Let the caller know the keys for each of the cached protos.
*proto_key = entry->proto_key;
*may_modify_variables = entry->tpu_program->may_modify_variables();
*hlo_metadata = entry->hlo_metadata;
// If the caller didn't supply a per_step_ref_holder then the caller is going
// to manually release the reference later via a call to Release().
if (per_step_ref_holder == nullptr) {
++entry->external_references;
} else {
// The caller wants its reference to be handed off to a per-step holder that
// will discard the reference when the step completes.
RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
TF_RET_CHECK(cast_ref_holder != nullptr);
cast_ref_holder->AddRef(entry);
}
// Remove the old LRU-table entry if it wasn't already marked for eviction.
auto erased = entries_by_last_use_.erase(entry->last_use);
// Update the LRU table indicating this entry is the most recently used.
entry->last_use = use_counter_++;
entries_by_last_use_[entry->last_use] = entry;
if (erased == 0) {
// The entry had been marked for eviction, or is newly created.
LookupEntryMarkedForEviction(entry, removed_entries);
}
// Log a little more verbosely when a key is added.
if (VLOG_IS_ON(1) || is_new_key) {
LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
<< " entry for key " << cache_key << " with session_name "
<< session_name << " cache is " << cache_store_.size()
<< " entries (" << cache_size_ + marked_for_eviction_size_
<< " bytes), "
<< " marked for eviction "
<< (cache_store_.size() - entries_by_last_use_.size())
<< " entries (" << marked_for_eviction_size_ << " bytes).";
}
return entry->initialization_status;
}
tensorflow::Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
const TpuCompilationCacheKey& cache_key,
const tensorflow::SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<tensorflow::Status(TpuProgram*)>& compile_function) {
std::vector<CompilationEntry*> removed_entries;
auto status = CompileIfKeyAbsentHelper(
cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
may_modify_variables, &removed_entries, hlo_metadata, compile_function);
for (auto entry : removed_entries) {
UnloadAndDestroy(entry);
}
return status;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,394 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/container/node_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"
namespace tensorflow {
namespace tpu {
const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
const char kCompilationCacheUnloaderResourceName[] =
"tpu_compilation_cache_unloader";
// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCacheInterface::MakePerStepRefHolder to create an instance of
// a concrete ref holder object.
class TpuCompilationRefHolder : public ResourceBase {
public:
~TpuCompilationRefHolder() override = default;
};
class TpuCompilationCacheInterface : public ResourceBase {
public:
using Status = ::stream_executor::port::Status;
// An entry in the compilation cache. The entry is deleted once it has been
// marked for eviction from the cache _and_ all steps that use it have
// completed. When the entry is first created, it is uninitialized and a
// client-supplied compilation function is run outside the cache's lock to
// generate the programs to be stored in the entry. Any other client that
// requests the entry will block until it has been initialized. Each entry has
  // a last_use value that is set from a monotonically-increasing counter in the
// cache whenever the entry is referenced. When the cache becomes full,
// entries are marked for eviction in LRU order.
//
// The bridge can request XLA to generate separate sharding and unsharding
// programs along with the main program; we use nested fields sharding_entry,
// unsharding_entry to store them under the main entry, and these two fields
// must be either both present or both absent. They have a back pointer
// main_entry to refer to the main program. These nested entries share the
// same cache key and the same lifetime as the main entry, so we use the
// refcount on the main entry to track the access to any of them.
// /-------------------------------\
// v \
// main_entry (refcount) -> sharding_entry -> main_entry
// ^ \
// | \-> unsharding_entry -> main_entry
// \--------------------------------------/
struct CompilationEntry : public core::RefCounted {
TpuCompilationCacheInterface* parent = nullptr; // Not owned.
bool initialized = false;
// The Status returned by the compilation function when the entry is
// initialized. This status will be returned to any client that requests the
// entry.
Status initialization_status;
// The uid describing this entry.
int64 uid;
std::vector<string> proto_key;
// Counter to keep track of LRU entries for the eviction policy.
int64 last_use = -1;
// The unique key describing this entry.
std::string subgraph_key;
// Entries representing the associated sharding and unsharding programs,
// which share the same life time of the owning main entry, so we always use
// the main entry's ref count.
std::unique_ptr<CompilationEntry> sharding_entry;
std::unique_ptr<CompilationEntry> unsharding_entry;
// The number of 'external' client-held references to the entry.
int external_references = 0;
std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
// The sum of the SpaceUsed of each of the elements of programs; an estimate
// of how much RAM the entry consumes, used to determine when entries must
// be marked for eviction.
int64 total_size = 0;
// Only used for the nested sharding/unsharding entries to point to the
// owning main entry.
CompilationEntry* main_entry = nullptr;
// Debug info in case we miss.
string cache_entry_debug_string;
// Compiled Tpu program.
std::unique_ptr<TpuProgram> tpu_program;
};
explicit TpuCompilationCacheInterface(int64_t max_cache_size);
~TpuCompilationCacheInterface() override;
TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&)
= delete;
Status CompileIfKeyAbsent(
const TpuCompilationCacheKey& cache_key,
const SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<tensorflow::Status(TpuProgram*)>& compile_function);
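  // Illustrative usage of CompileIfKeyAbsent (sketch in a comment, not part of
  // the build): a compile op would typically pass a callback that fills in the
  // TpuProgram on a cache miss, e.g.
  //
  //   int64 uid;
  //   std::vector<string> proto_key;
  //   std::vector<bool> may_modify_variables;
  //   std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
  //   Status s = cache->CompileIfKeyAbsent(
  //       key, session_metadata, per_step_ref_holder, &uid, &proto_key,
  //       &may_modify_variables, &hlo_metadata,
  //       [&](TpuProgram* program) {
  //         // CompileSubgraphToTpuProgram is a hypothetical helper that runs
  //         // the XLA compilation and populates `program`.
  //         return CompileSubgraphToTpuProgram(program);
  //       });
  //
  // On a cache hit the callback is not invoked and the cached entry is reused.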
static TpuCompilationCacheKey CreateCompilationCacheKey(
absl::string_view function_name, uint64 function_library_fingerprint,
absl::string_view mlir_module,
const tensorflow::OpInputList& guaranteed_constants,
const std::vector<tensorflow::TensorShape>& dynamic_shapes,
const tensorflow::tpu::TPUCompileMetadataProto& metadata,
const TpuMeshStateInterface& mesh_state);
string DebugString() const override { return "TpuCompilationCacheInterface"; }
// Makes a reference holder for this cache, that can be stored in the per-step
// resource manager and will ensure that compiled entries persist until the
// end of a step.
TpuCompilationRefHolder* MakePerStepRefHolder();
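  // Illustrative sketch of MakePerStepRefHolder usage (assumed call site, kept
  // in a comment): an execute op can park the holder in the step's resource
  // manager so the references are discarded when the step completes, e.g.
  //
  //   TpuCompilationRefHolder* holder = cache->MakePerStepRefHolder();
  //   TF_RETURN_IF_ERROR(step_resource_manager->Create(
  //       container_name, "tpu_compilation_ref_holder", holder));
  //   // ... then pass `holder` as per_step_ref_holder to CompileIfKeyAbsent.
  //
  // The container and resource names above are placeholders; only the
  // ownership pattern (holder destroyed at end of step -> refs released) is
  // implied by this interface.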
// Differences between MarkEntryForEviction and Release:
// There are two modes of managing cache entries:
// 1) LRU eviction + pinning; 2) manual.
// We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
// Otherwise it is manual mode (mainly used by XRT).
// MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
// entries when callers know that they do not need them anymore.
// Release should only be used in mode 2) to explicitly remove an entry.
// Mark the entry indexed by `subgraph_uid` for eviction. This should only be
// called if per_step_ref_holder was NOT nullptr in the corresponding call to
// CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
// subgraph_uid).
Status MarkEntryForEviction(int64 subgraph_uid);
// Manually discards a reference to the compiled subgraph. This should only be
// called if per_step_ref_holder was nullptr in the corresponding call to
// CompileIfKeyAbsent(subgraph_key, ...).
Status Release(int64 subgraph_uid);
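  // Illustrative sketch of the two modes (assumed callers, kept in a comment):
  //
  //   // Mode 1) LRU eviction + pinning: a per-step ref holder was supplied to
  //   // CompileIfKeyAbsent, so the entry can be eagerly marked for eviction
  //   // once it is known to be unneeded.
  //   TF_RETURN_IF_ERROR(cache->MarkEntryForEviction(uid));
  //
  //   // Mode 2) manual (e.g. XRT): no ref holder was supplied, so the caller
  //   // drops its external reference explicitly instead.
  //   TF_RETURN_IF_ERROR(cache->Release(uid));
  //
  // Per the comments above, each call is only valid in its corresponding mode;
  // MarkEntryForEviction fails if external references are still held.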
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by key. On success a pointer to an EntryRef
// holding the program is returned in entry.
Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry);
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by uid. On success a pointer to an EntryRef
// holding the program is returned in entry.
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry);
// Mutates the main entry ref to point to the entry's subentry
// (for sharding/unsharding) or main entry (unchanged) representing the
// fetch target. The entry ref needs to point to the main entry before this
// call.
//
// If the requested subentry does not exist, the ref will point to a nullptr
// entry.
Status ToSubEntryRef(CompilationCacheEntryRef* entry,
CompilationCacheFetchTarget fetch_target) const;
private:
// Wrapper for a cache entry that holds a reference to the entry until the
// wrapper is deleted. This wrapper is the concrete type of
// CompilationCacheEntryRef returned by Lookup.
class EntryRefImpl : public CompilationCacheEntryRef {
public:
EntryRefImpl(TpuCompilationCacheInterface* parent, CompilationEntry* entry,
int index);
~EntryRefImpl() override;
CompilationCacheEntry get() override;
// Mutates this ref to point to the entry's subentry (for
// sharding/unsharding) or main entry (unchanged) as specified by
// fetch_target. The refcount is kept unchanged, since we only track the
// refcount of the main entry. The entry ref needs to point to the main
// entry before this call.
//
// If the requested subentry does not exist, the ref will point to a nullptr
// entry, and the original entry will be unref'ed.
Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);
private:
TpuCompilationCacheInterface* parent_; // Not owned.
// A reference to entry_ is acquired in the constructor and released via
// parent->DiscardEntryRefs in the destructor.
CompilationEntry* entry_;
// The program in entry_ that is returned by the get method.
int index_;
};
  // Private implementation of the generic TpuCompilationRefHolder that knows
  // about CompilationEntry entries.
class RefHolder : public TpuCompilationRefHolder {
public:
explicit RefHolder(TpuCompilationCacheInterface* parent) : parent_(parent) {
parent_->Ref();
}
~RefHolder() override {
// Release our reference to the parent.
parent_->Unref();
}
// Adds entry to the list of entries that will be released when the
// RefHolder is destroyed. Each entry is released via a call to
// parent_->DiscardEntryRefs.
void AddRef(CompilationEntry* entry) {
entries_.push_back(entry);
}
string DebugString() const override {
return "TpuCompilationCacheInterface::RefHolder";
}
private:
TpuCompilationCacheInterface* parent_; // Not owned.
std::vector<CompilationEntry*> entries_;
};
  // The bulk of the implementation of CompileIfKeyAbsent(), with the exception
  // of unloading programs that correspond to possibly removed cache
// entries. The split helps to manage locking since we prefer to perform
// unloading without holding extra locks.
Status CompileIfKeyAbsentHelper(
const TpuCompilationCacheKey& subgraph_key,
const SessionMetadata* session_metadata,
TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
std::vector<CompilationEntry*>* removed_entries,
std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
const std::function<Status(TpuProgram*)>& compile_function);
// This is called by the cache when entry is marked for eviction; by
// a RefHolder (via DiscardEntryRefs) when a step completes; and by
// an EntryRefImpl when it is destroyed. Releases one reference to entry
// if more than 1 remains. If only one reference is left, the entry is removed
  // from cache_ and is returned to the caller, which must eventually call
// UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
// to avoid holding the lock during program unloading.
ABSL_MUST_USE_RESULT CompilationEntry* DiscardEntryRef(
CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Convenience method called by ~RefHolder without mu_ held. Calls
// DiscardEntryRef on every element of entries.
void DiscardEntryRefs(
gtl::ArraySlice<CompilationEntry*> entries);
// Marks the oldest unmarked entry for eviction. Requires that there is at
// least one such entry. In case the evicted entry had only 1 reference it
// is removed from the cache and returned to the caller which must eventually
// call UnloadAndDestroy.
CompilationEntry* MarkOldestEntryForEviction()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Updates datastructures to indicate that entry, which had been marked for
// eviction, has been looked up. This is called by CompileIfKeyAbsent when an
// entry is newly created, or an entry that has been marked for eviction but
// not yet evicted is looked up.
//
// First the entry is unmarked for eviction, i.e. the cache gains a reference
// to entry, entry's last_use field is set to be the most recent value of
// use_counter_ and entries_by_last_use_ is updated accordingly.
//
// Next, the size of the cache is examined to see if any other entries need to
// be marked for eviction now that entry has been unmarked. While the total
// size of unmarked cached entries is greater than max_cache_size_, entries
// are marked for eviction in LRU order. The most recently used entry is never
// marked for eviction, so an entry larger than the max cache size will remain
// in the cache until it is replaced by something else. In case some entries
  // actually were removed from the cache, they are returned to the caller via
// removed_entries. The caller must eventually delete them by calling
// UnloadAndDestroy.
void LookupEntryMarkedForEviction(
CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Removes the entry with given key from cache.
size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Inserts the given key and entry to cache.
void InsertEntry(const std::string& key,
const TpuCompilationCacheKey& subgraph_key,
CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the cache key matching given subgraph_key.
std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Creates a new entry by running initialize_programs and places it in the
// cache to be looked up by key. The new entry is in the 'marked for eviction'
// state (not present in entries_by_last_use_) and the caller is expected to
// call LookupEntryMarkedForEviction after InitializeEntry.
//
// **InitializeEntry releases mu_ during the call to initialize_programs.**
CompilationEntry* InitializeEntry(
const string& key,
const std::function<Status(TpuProgram*)>& initialize_program,
const TpuCompilationCacheKey& subgraph_key)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Unloads the program associated with the entry from all local devices
// and deletes the entry itself. It is assumed no one else has a reference
// to it and all related keys had already been removed from the cache.
// The call can perform device IO so no locks should be held while calling it.
void UnloadAndDestroy(CompilationEntry* entry) ABSL_LOCKS_EXCLUDED(mu_);
// The maximum size of entries that are stored in the cache before entries are
// marked for eviction.
const int64 max_cache_size_;
mutable absl::Mutex mu_;
// The total size of entries that are stored and not marked for eviction.
int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;
// The total size of entries that are marked for eviction.
int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
// The value to assign to the last_use field of the next entry that is looked
// up.
int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;
// session_key_map_ and fingerprint_key_map_ are used for looking up the
// cache_ key matching a given subgraph key. When doing a lookup, check
  // session_key_map_ first to avoid unnecessary fingerprint computation.
// Map from key prefix + session_handle to a cache_ key.
std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);
// Map from key prefix + fingerprint to a cache_ key.
std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache. An entry is
// marked for eviction iff it is present in cache_ and not in
// entries_by_last_use_.
std::unordered_map<string, CompilationEntry*> cache_store_
ABSL_GUARDED_BY(mu_);
// All the subgraph entries that can be looked up in the cache, indexed by
// uid.
absl::node_hash_map<int64, CompilationEntry*> entries_by_uid_
ABSL_GUARDED_BY(mu_);
// All the protos that can be looked up in the cache, indexed by proto
// key. The value of the map is a subgraph and the index of the proto compiled
// for that subgraph.
std::unordered_map<string, std::pair<CompilationEntry*, int>>
entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
// Map from last_use to entry, used to mark entries for eviction in LRU
// order. If an entry's last_use counter is not present as a key in
// entries_by_last_use_ then the entry has been marked for eviction.
std::map<int64, CompilationEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
};
} // namespace tpu
} // namespace tensorflow
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
#include <functional>
#include <string>
#include "absl/types/optional.h"
namespace tensorflow {
namespace tpu {
struct TpuCompilationCacheKey {
// Prefix of the key.
std::string prefix;
  // A boolean flag to specify if `guaranteed_const` is used. Guaranteed const
  // is normally used in TPU inference to avoid re-copying unchanged variables
  // onto the TPU device. It promises the value is identical for every
  // execution in the same session even if the actual value changes in later
  // executions.
bool has_guaranteed_const = false;
// Unique session identifier. It is set when `has_guaranteed_const` is true.
std::string session_handle;
  // Fingerprint of the `guaranteed_const` value. It is set when
  // `has_guaranteed_const` is true; the value is produced lazily on demand.
std::function<std::string()> guaranteed_const_fingerprint;
// A more verbose key for debugging purpose.
std::string debug_string;
explicit TpuCompilationCacheKey() {}
explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {}
};
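// Illustrative sketch (kept in a comment so it does not affect the build): a
// key for a program with guaranteed constants carries a lazy fingerprint
// callback in addition to the prefix, e.g.
//
//   TpuCompilationCacheKey key("some_prefix");   // placeholder prefix
//   key.has_guaranteed_const = true;
//   key.session_handle = "session-0";            // placeholder handle
//   key.guaranteed_const_fingerprint = [] {
//     return std::string("fingerprint");         // computed on demand
//   };
//
// The literal values are placeholders; in practice keys are produced by
// TpuCompilationCacheInterface::CreateCompilationCacheKey.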
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_

View File

@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
namespace tensorflow {
namespace tpu {
namespace {
class CompilationCacheFetchTargetUtility {
public:
CompilationCacheFetchTargetUtility()
: names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}
std::string name(CompilationCacheFetchTarget target) const {
return names_[static_cast<int>(target)];
}
private:
const std::vector<std::string> names_;
};
std::string GetName(CompilationCacheFetchTarget target) {
static const auto* util = new CompilationCacheFetchTargetUtility();
return util->name(target);
}
} // namespace
TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
TpuCompilationCacheInterface* cache)
: cache_(cache) {}
TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
cache_->Unref();
}
Status TpuCompilationCacheLocalLookup::Lookup(
const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
/*level=*/2);
Status s = cache_->Lookup(proto_key, entry);
VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status "
<< s;
if (!s.ok()) {
return s;
}
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
return s;
}
Status TpuCompilationCacheLocalLookup::Lookup(
int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) {
profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
/*level=*/2);
Status s = cache_->Lookup(uid, proto_index, entry);
VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
<< " in local subgraph cache status " << s;
if (!s.ok()) {
return s;
}
s = cache_->ToSubEntryRef(entry->get(), fetch_target);
VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
<< s;
return s;
}
string TpuCompilationCacheLocalLookup::DebugString() const {
return "TpuCompilationCacheLocalLookup";
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
namespace tensorflow {
namespace tpu {
// Base class allowing Execute Ops to look up ISA protos. Different subclasses
// are used when the execute Op is in the same address space as the compile Op,
// and when they need to communicate over RPC.
class TpuCompilationCacheLookup : public ResourceBase {
public:
~TpuCompilationCacheLookup() override = default;
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by key. On success a wrapper for the proto is
// returned in program. The wrapper is guaranteed to be valid only during the
// execution of the Op requesting the proto.
//
// Only one of the main, sharding, unsharding entries is fetched, as specified
// in fetch_target.
//
// If the compilation does not create sharding/unsharding programs, but the
// fetch_target requests one of them, then after this call
// (*entry)->get().get_executable() will return nullptr.
virtual Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
return Lookup(proto_key, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
// Looks up an executable corresponding to the model-parallel core index of
// the subgraph represented by uid. On success a wrapper for the proto is
// returned in program. The wrapper is guaranteed to be valid only during the
// execution of the Op requesting the proto.
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) = 0;
virtual Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry) {
return Lookup(uid, proto_index, std::move(entry),
CompilationCacheFetchTarget::MAIN);
}
};
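// Illustrative sketch of a lookup with a fetch target (assumed caller, kept in
// a comment): an execute op fetches the program for one core and selects the
// sub-program it needs, e.g.
//
//   std::unique_ptr<CompilationCacheEntryRef> entry;
//   TF_RETURN_IF_ERROR(lookup->Lookup(proto_key, &entry,
//                                     CompilationCacheFetchTarget::SHARDING));
//   // If no sharding program was compiled, the returned ref wraps a null
//   // program (see the fetch_target comment on Lookup above).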
// Forward declaration to break cycle dependency graph.
class TpuCompilationCacheInterface;
// Class for looking up ISA protos when the execute and compile Op are in the
// same address space. The proto is simply looked up in the compilation cache,
// without any serialization taking place.
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
public:
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
~TpuCompilationCacheLocalLookup() override;
Status Lookup(const string& proto_key,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
Status Lookup(int64 uid, int proto_index,
std::unique_ptr<CompilationCacheEntryRef>* entry,
CompilationCacheFetchTarget fetch_target) override;
string DebugString() const override;
private:
// The subgraph compilation cache, in the same process address space where the
// lookups are happening.
TpuCompilationCacheInterface* cache_;
};
} // namespace tpu
} // namespace tensorflow
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
namespace tensorflow {
namespace tpu {
/* static */
void TpuCompilationCacheMetrics::IncrementCacheLookupCount(
bool is_cache_hit, absl::string_view session_name) {
// A placeholder for tracking metrics.
}
/* static */
void TpuCompilationCacheMetrics::SetCacheEntryCount(int64 count) {
// A placeholder for tracking metrics.
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tpu {
// Tracks Tpu compilation cache metrics.
class TpuCompilationCacheMetrics {
public:
  // Increments the cache lookup count.
static void IncrementCacheLookupCount(bool is_cache_hit,
absl::string_view session_name);
// Sets the total count of cache entries.
static void SetCacheEntryCount(int64 count);
};
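// Illustrative sketch (mirrors the cache's call sites, kept in a comment): the
// cache reports a hit/miss on every lookup and the entry count after removals,
// e.g.
//
//   TpuCompilationCacheMetrics::IncrementCacheLookupCount(
//       /*is_cache_hit=*/false, session_name);
//   TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
//
// Both functions are currently placeholders and can be wired to a concrete
// metrics backend.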
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_

View File

@ -0,0 +1,144 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow.tpu;
import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/tpu/compile_metadata.proto";
import "tensorflow/core/tpu/kernels/tpu_executable_info.proto";
message PerCoreVariableIndices {
// For each resource variable output, what was the index of the corresponding
// input and was it updated? The indices are sorted by input order.
repeated TPUExecutableInfoProto.UpdateIndexPair variable_indices = 1;
}
message PerCoreArgShapes {
// Argument shapes for each Tpu core.
repeated xla.ShapeProto shapes = 1;
}
message PerCoreOutputShapes {
// Output shapes for each Tpu core.
repeated xla.ShapeProto shapes = 1;
}
message OutputDescriptionProto {
// Type and shape of the output. The shape is the unflattened shape.
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
// variable's value.
tensorflow.DataType type = 1;
tensorflow.TensorShapeProto shape = 2;
// Constant output value, if known to be constant at JIT compilation time.
// 'Tensor' is in host memory.
bool is_constant = 3;
tensorflow.TensorProto constant_value = 4;
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
// the index of the input that contains the resource.
int32 input_index = 5;
// Whether this output is a TensorList.
bool is_tensor_list = 6;
}
// Describes a variable write side effect of the computation.
message ResourceUpdateProto {
// Index of the input that contains the variable resource to write to.
int32 input_index = 1;
// Type and shape of the tensor to be written back.
// The `shape` field has the same meaning as the Argument::shape field.
tensorflow.DataType type = 2;
tensorflow.TensorShapeProto shape = 3;
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_resources` is true.)
bool modified = 4;
// If the resource is a TensorArray, the set of gradients read or written.
map<string, bool> tensor_array_gradients_accessed = 5;
}
// Describes the result of a XLA Compiler compilation.
message XlaCompilationResultProto {
// Vector that maps from the parameters of the XLA computation to their
// original argument positions. To handle compile-time constant inputs, the
// parameters to the XLA computation may be a subset of the original
  // arguments. The relative ordering of parameters is maintained.
repeated int32 input_mappings = 1;
// Input shapes of the computation. If we are flattening inputs, these are
// the flattened shapes.
repeated xla.ShapeProto xla_input_shapes = 2;
// Output shape in XLA format. The output shape is always a tuple. If we
// are flattening outputs, these are the flattened shapes.
xla.ShapeProto xla_output_shape = 3;
// TensorFlow shapes of outputs, together with the values of any
// constant arguments. Vector indexed by Tensorflow _Retval number,
// containing both constant and non-constant results.
repeated OutputDescriptionProto outputs = 4;
// TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
// matching RecvAtHost/SendFromHost Ops in the outer graph.
tf2xla.HostComputeMetadata host_compute_metadata = 5;
// Resources whose values were updated by the computation, ordered
// by return value position (which is the same as the order the resources
// were passed as arguments). Resource updates follow the non-constant
// results in the outputs of XLA computation.
repeated ResourceUpdateProto resource_updates = 6;
// The XLA computation built from the tensorflow subgraph.
xla.HloModuleProto computation = 7;
}
// TpuAotCompilationRequestProto represents a compilation request for performing
// ahead-of-time (AOT) compilation of XLA Computations into XLA HLO IR.
message TpuAotCompilationRequestProto {
  // A set of HLO modules built to run concurrently
// across different devices.
xla.HloModuleGroupProto hlo_module_group = 1;
// Compilation metadata.
TPUCompileMetadataProto metadata = 2;
// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
xla.DeviceAssignmentProto device_assignment = 3;
  // Per TPU core program argument shapes.
repeated PerCoreArgShapes per_core_arg_shapes = 4;
  // Per TPU core program output shapes.
repeated PerCoreOutputShapes per_core_output_shapes = 5;
  // Per TPU core information containing the index of the corresponding input
  // and whether it was updated. The indices are sorted by input order.
repeated PerCoreVariableIndices per_core_variable_indices = 6;
// XLA compiler compilation result.
XlaCompilationResultProto compilation_result = 7;
}

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
enum TpuCoreTypeEnum {
kTensorCore,
kEmbeddingV1,
kEmbeddingV2,
};
typedef struct XLA_TpuProgram XLA_TpuProgram;
// Property for creating compilation cache key.
struct CompilationCacheKeyProperty {
const char* config_prefix;
const char* shapes_prefix;
const char* function_name;
const char* mlir_module;
const int32_t* device_ids;
size_t device_ids_size;
int32_t guaranteed_constants_size;
uint64_t function_library_fingerprint;
int32_t num_cores_per_replica;
int32_t num_replicas;
const XLA_TpuMeshState* mesh_state;
};
extern "C" {
// Creates a new TPU program.
XLA_TpuProgram* TpuProgram_New();
// Destroys the `tpu_program`.
void TpuProgram_Free(XLA_TpuProgram* tpu_program);
// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
SE_Status* status);
// Gets TPU program size in bytes from the `tpu_program`.
int64_t TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);
// Logs the summary of current memory state snapshot of the `tpu_program`.
bool TpuProgram_LogProgramMemorySummary(const XLA_TpuProgram* tpu_program);
// Gets TPU program executable info from the `tpu_program`.
void TpuProgram_GetExecutableInfo(const XLA_TpuProgram* tpu_program,
TpuSerializedProto* executable_info);
// Gets host transfer info proto.
void TpuProgram_GetHostTransferInfo(
const XLA_TpuProgram* tpu_program,
TpuSerializedProto* host_transfer_info);
// Gets HLO metadata proto.
void TpuProgram_GetHloMetadata(const XLA_TpuProgram* tpu_program,
TpuSerializedProto* hlo_metadata);
// Returns the number of available TPU cores.
int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
TpuCoreTypeEnum tpu_core_type);
// Creates a unique compilation cache `key` used for `put` and `get` operations.
// The returned buffer is heap-allocated; the caller takes ownership and must
// free it.
const char* TpuCompile_CreateCompilationCacheKey(
CompilationCacheKeyProperty property);
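// Illustrative use of TpuCompile_CreateCompilationCacheKey (mirrors the C++
// caller in the compilation cache, kept in a comment):
//
//   CompilationCacheKeyProperty property{
//       config_prefix.data(),        shapes_prefix.data(),
//       function_name.data(),        mlir_module.data(),
//       flattened_device_ids.data(), flattened_device_ids.size(),
//       guaranteed_constants_size,   function_library_fingerprint,
//       num_cores_per_replica,       num_replicas,
//       mesh_state_ptr};
//   const char* prefix = TpuCompile_CreateCompilationCacheKey(property);
//   // ... use `prefix` to build a TpuCompilationCacheKey ...
//   delete[] prefix;  // the caller owns, and must free, the returned buffer
//
// The variable names above are placeholders for the values the caller already
// has in scope.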
// Creates a guaranteed const fingerprint. Guaranteed const is normally used in
// TPU inference to avoid re-copying unchanged variables onto the TPU device.
// It promises the value is identical for every execution in the same session
// even if the actual value changes in later executions.
uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
const char* data,
size_t size);
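// Illustrative sketch (assumed usage, kept in a comment): the fingerprint can
// be chained across a list of guaranteed constants by feeding the previous
// result back in as the seed, e.g.
//
//   uint64_t fp = initial_fingerprint;
//   for (const Tensor& constant : guaranteed_constants) {
//     const auto data = constant.tensor_data();
//     fp = TpuCompile_CreateGuaranteedConstFingerprint(fp, data.data(),
//                                                      data.size());
//   }
//
// `initial_fingerprint` and the iteration are assumptions for illustration;
// the compilation cache wraps this in a GuaranteedConstFingerprint() helper.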
// Checks whether TPU compilation is enabled.
bool TpuCompile_IsTpuCompilationEnabled();
// Compiles the computations ahead of time using the XLA TPU compiler and
// returns TPU programs ready for execution.
void TpuCompile_CompileAheadOfTime(
TpuSerializedProto aot_compilation_request,
XLA_TpuProgram** tpu_programs[],
size_t* count, SE_Status* status);
// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto.
void TpuCompile_BuildXLADeviceAssignment(
TpuSerializedProto serialized_tpu_compile_metadata,
const XLA_TpuMeshState* mesh_state,
TpuSerializedProto* serialized_device_assignment, SE_Status* status);
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
void TpuCompile_ToTpuShapeRepresentation(
TpuSerializedProto serialized_xla_shape, int data_type,
bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
SE_Status* status);
} // extern "C"
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
namespace tensorflow {
namespace internal {
namespace {
static bool tpu_compilation_cancellation_terminates_process = true;
static bool tpu_compilation_failure_closes_chips = true;
} // namespace
void SetTpuCompilationCancellationTerminatesProcess(bool b) {
tpu_compilation_cancellation_terminates_process = b;
}
bool TpuCompilationCancellationTerminatesProcess() {
return tpu_compilation_cancellation_terminates_process;
}
void SetTpuCompilationFailureClosesChips(bool value) {
tpu_compilation_failure_closes_chips = value;
}
bool TpuCompilationFailureClosesChips() {
return tpu_compilation_failure_closes_chips;
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
#include <string>
namespace tensorflow {
namespace internal {
// Setter and getter that determine how TPUCompile responds to cancelled
// compilation. By default this is true, meaning cancelled compilation will
// abort the process, since that's the only mechanism we have available.
//
// Setting this to false allows the process to remain alive, and should only be
// used in tests.
void SetTpuCompilationCancellationTerminatesProcess(bool b);
bool TpuCompilationCancellationTerminatesProcess();
// Setter and getter that determine whether TPU compilation failure will cause
// chips to close. By default this is true, which is suitable for training. For
// inference, we never want servers to die, so the chips are kept alive.
// See b/109873767.
void SetTpuCompilationFailureClosesChips(bool value);
bool TpuCompilationFailureClosesChips();
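// Illustrative sketch (assumed call site, kept in a comment): an inference
// server that must keep serving through compilation failures would flip the
// flag during process initialization, e.g.
//
//   tensorflow::internal::SetTpuCompilationFailureClosesChips(false);
//
// Training jobs typically keep the default (true) so a failed compilation
// closes the chips.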
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_

View File

@ -0,0 +1,439 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
namespace tpu {
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
using xla::ComputationLayout;
using xla::DebugOptions;
using xla::DeviceAssignment;
using xla::HloModuleConfig;
using xla::HloSharding;
using xla::InvalidArgument;
using xla::ProgramShape;
using xla::Shape;
using xla::ShapeTree;
using xla::ShapeUtil;
Status ValidateResultShape(const Shape& client_shape,
const Shape& result_shape) {
TF_RETURN_IF_ERROR(
xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
"with result shape %s",
xla::ShapeUtil::HumanStringWithLayout(client_shape),
xla::ShapeUtil::HumanString(result_shape));
}
return Status::OK();
}
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
absl::optional<const Shape> result_layout,
absl::optional<const DeviceAssignment> device_assignment, int replica_count,
int num_partitions, const DebugOptions* debug_options, const int* seed,
const int* launch_id, const bool* alias_passthrough_params,
const xla::FusionConfigCollection* fusion_config_collection,
const std::vector<std::vector<bool>>* fusion_config) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %u given",
program_shape.parameters_size(),
argument_shapes.size());
}
for (int i = 0; i < argument_shapes.size(); ++i) {
// Verify that shape of arguments matches the shape of the arguments in the
// ProgramShape.
if (!ShapeUtil::Compatible(argument_shapes[i],
program_shape.parameters(i))) {
return InvalidArgument(
"Argument does not match shape of computation parameter %d: want "
"%s, got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)),
ShapeUtil::HumanString(argument_shapes[i]));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
argument_shapes[i]));
}
if (result_layout.has_value()) {
TF_RETURN_IF_ERROR(
ValidateResultShape(result_layout.value(), program_shape.result()));
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
result_layout.value()));
} else {
// If the result layout is not set, then choose the default.
computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(replica_count);
config->set_num_partitions(num_partitions);
if (seed != nullptr) {
config->set_seed(*seed);
}
if (launch_id != nullptr) {
config->set_launch_id(*launch_id);
}
if (debug_options != nullptr) {
config->set_debug_options(*debug_options);
} else {
config->set_debug_options(xla::GetDebugOptionsFromFlags());
}
// TODO(henrytan): set intra_op_parallelism_threads.
// Reference:
// tensorflow/compiler/xla/service/service.cc?l=324.
if (device_assignment.has_value()) {
config->set_static_device_assignment(device_assignment.value());
}
if (alias_passthrough_params != nullptr) {
config->set_alias_passthrough_params(*alias_passthrough_params);
}
if (fusion_config_collection != nullptr && fusion_config != nullptr &&
*fusion_config_collection != xla::FusionConfigCollection::kOff) {
config->set_fusion_config_collection(*fusion_config_collection);
*config->mutable_fusion_config() = *fusion_config;
}
return std::move(config);
}
StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
const xla::ProgramShape& program_shape,
absl::Span<const Shape> argument_shapes,
absl::optional<const Shape> result_layout,
absl::optional<const DeviceAssignment> device_assignment, int replica_count,
int num_partitions, const DebugOptions* debug_options) {
return CreateModuleConfig(program_shape, argument_shapes, result_layout,
device_assignment, replica_count, num_partitions,
debug_options, /*seed=*/nullptr,
/*launch_id=*/nullptr,
/*alias_passthrough_params=*/nullptr,
/*fusion_config_collection=*/nullptr,
/*fusion_config=*/nullptr);
}
ShapeTree<HloSharding> GetSubtree(
const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
ShapeTree<HloSharding> element_shape_tree(
xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
element_index),
HloSharding::Replicate());
xla::ShapeIndex src_index;
src_index.push_back(element_index);
element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
return element_shape_tree;
}
Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
int64 device) {
if (shape.IsTuple()) {
ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
std::vector<Shape> arg_shapes;
for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
HloSharding element_sharding = tuple_shape_tree.element({i});
if (element_shape.IsTuple()) {
element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
}
if (element_sharding.UsesDevice(device)) {
arg_shapes.push_back(
GetPerDeviceShape(element_shape, element_sharding, device));
}
}
return xla::ShapeUtil::MakeTupleShape(arg_shapes);
}
if (sharding.IsTileMaximal()) {
return shape;
}
std::vector<int64> dimensions;
std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
for (int64 i = 0; i < limit.size(); ++i) {
dimensions.push_back(limit[i] - offset[i]);
}
if (shape.has_layout()) {
return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
shape.layout().minor_to_major());
}
return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}
Status AddVariableUpdatesToCores(
const TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,
std::vector<bool>* may_modify_variables,
std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
// Add all variables to the corresponding core.
may_modify_variables->resize(metadata.num_cores_per_replica(), false);
int resource_update_pos = 0;
for (int i = 0; i < metadata.args_size(); ++i) {
const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
const auto& sharding = proto_arg.sharding();
bool updated = false;
if (resource_update_pos < compilation_result.resource_updates.size()) {
const XlaCompiler::ResourceUpdate& update =
compilation_result.resource_updates[resource_update_pos];
if (update.input_index == i) {
updated = true;
int pos = compilation_result.outputs.size() + resource_update_pos;
xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
compilation_result.xla_output_shape, pos);
auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
(*per_core_output_shapes)[core].push_back(per_core_shape);
(*may_modify_variables)[core] =
(*may_modify_variables)[core] || update.modified;
};
if (sharding.type() == xla::OpSharding::MAXIMAL) {
add_to_core(sharding.tile_assignment_devices(0), shape);
} else if (sharding.type() == xla::OpSharding::OTHER) {
auto sharding_or =
xla::HloSharding::FromProto(proto_arg.sharding());
TF_RET_CHECK(sharding_or.ok());
for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
xla::Shape per_core_shape =
GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
add_to_core(core, per_core_shape);
}
} else {
TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
for (int64 core = 0; core < metadata.num_cores_per_replica();
++core) {
add_to_core(core, shape);
}
}
++resource_update_pos;
}
}
if (sharding.type() == xla::OpSharding::MAXIMAL) {
(*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
.push_back(
std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
} else if (sharding.type() == xla::OpSharding::OTHER) {
for (int core : sharding.tile_assignment_devices()) {
(*per_core_variable_indices)[core].push_back(
std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
}
} else {
TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
(*per_core_variable_indices)[core].push_back(
std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
}
}
}
}
return Status::OK();
}
Status ComputeOutputShapesForEachCore(
const tpu::TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
for (int i = 0; i < metadata.retvals_size(); ++i) {
const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
<< "TPU compilation output " << i
<< " has a compile-time constant value. "
"This should never happen.";
xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
compilation_result.xla_output_shape, i);
auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
(*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
};
if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
add_shape_to_core(retval.sharding().tile_assignment_devices(0),
std::move(shape));
} else if (retval.sharding().type() == xla::OpSharding::OTHER) {
auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
TF_RET_CHECK(sharding_or.ok());
for (int64 core : retval.sharding().tile_assignment_devices()) {
xla::Shape per_core_shape =
GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
add_shape_to_core(core, std::move(per_core_shape));
}
} else {
TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
<< "Unsupported sharding type for return value " << i << ".";
for (int core = 0; core < per_core_output_shapes->size(); ++core) {
add_shape_to_core(core, shape);
}
}
}
return Status::OK();
}
Status CreateHloModules(
const TPUCompileMetadataProto& metadata,
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
const absl::optional<xla::DeviceAssignment>& device_assignment,
std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
TF_RET_CHECK(
compilation_result.computation->proto().has_host_program_shape());
auto debug_options = xla::DebugOptions();
debug_options.set_xla_step_marker_location(metadata.step_marker_location());
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::HloModuleConfig> module_config,
CreateModuleConfig(
xla::ProgramShape(
compilation_result.computation->proto().host_program_shape()),
compilation_result.xla_input_shapes,
compilation_result.xla_output_shape, device_assignment,
metadata.num_replicas(), metadata.num_cores_per_replica(),
&debug_options));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::HloModule> hlo_module,
xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
*module_config));
DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
hlo_modules->push_back(std::move(hlo_module));
return Status::OK();
}
XlaCompilationResultProto SerializeCompilationResult(
const XlaCompiler::CompilationResult& compilation_result) {
XlaCompilationResultProto compilation_result_proto;
for (int input_mapping : compilation_result.input_mapping) {
compilation_result_proto.add_input_mappings(input_mapping);
}
for (const Shape& input_shape : compilation_result.xla_input_shapes) {
*(compilation_result_proto.add_xla_input_shapes()) = input_shape.ToProto();
}
*(compilation_result_proto.mutable_xla_output_shape()) =
compilation_result.xla_output_shape.ToProto();
for (const XlaCompiler::OutputDescription& output_description :
compilation_result.outputs) {
auto* new_output = compilation_result_proto.add_outputs();
new_output->set_type(output_description.type);
output_description.shape.AsProto(new_output->mutable_shape());
new_output->set_is_constant(output_description.is_constant);
output_description.constant_value.AsProtoField(
new_output->mutable_constant_value());
new_output->set_input_index(output_description.input_index);
new_output->set_is_tensor_list(output_description.is_tensor_list);
}
*compilation_result_proto.mutable_host_compute_metadata() =
compilation_result.host_compute_metadata;
for (const XlaCompiler::ResourceUpdate& resource_update :
compilation_result.resource_updates) {
auto* new_resource_update = compilation_result_proto.add_resource_updates();
new_resource_update->set_input_index(resource_update.input_index);
new_resource_update->set_type(resource_update.type);
resource_update.shape.AsProto(new_resource_update->mutable_shape());
new_resource_update->set_modified(resource_update.modified);
for (const std::string& gradient_access :
resource_update.tensor_array_gradients_accessed) {
new_resource_update->mutable_tensor_array_gradients_accessed()->insert(
{gradient_access, true});
}
}
if (compilation_result.computation != nullptr) {
*compilation_result_proto.mutable_computation() =
compilation_result.computation->proto();
}
return compilation_result_proto;
}
StatusOr<TpuAotCompilationRequestProto> CreateTpuAotCompilationRequest(
const xla::HloModuleGroup& module_group,
const XlaCompiler::CompilationResult& compilation_result,
const TPUCompileMetadataProto& metadata,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
const std::vector<std::vector<std::pair<int, bool>>>&
per_core_variable_indices,
const absl::optional<xla::DeviceAssignment>& device_assignment) {
VLOG(1) << "CreateTpuAotCompilationRequest.";
TpuAotCompilationRequestProto aot_request;
*(aot_request.mutable_hlo_module_group()) = module_group.ToProto();
*(aot_request.mutable_metadata()) = metadata;
if (device_assignment.has_value()) {
xla::DeviceAssignmentProto device_assignment_proto;
Status status = device_assignment->Serialize(&device_assignment_proto);
if (!status.ok()) {
return status;
}
*(aot_request.mutable_device_assignment()) = device_assignment_proto;
}
for (const auto& arg_shapes : per_core_arg_shapes) {
auto* new_shape_list = aot_request.add_per_core_arg_shapes();
for (const auto& arg_shape : arg_shapes) {
*new_shape_list->add_shapes() = arg_shape.ToProto();
}
}
for (const auto& output_shapes : per_core_output_shapes) {
auto* new_shape_list = aot_request.add_per_core_output_shapes();
for (const auto& output_shape : output_shapes) {
*new_shape_list->add_shapes() = output_shape.ToProto();
}
}
for (const auto& variable_indices : per_core_variable_indices) {
auto* new_list = aot_request.add_per_core_variable_indices();
for (const auto& variable_index : variable_indices) {
auto* core_index = new_list->add_variable_indices();
core_index->set_index(variable_index.first);
core_index->set_updated(variable_index.second);
}
}
XlaCompilationResultProto compilation_result_proto =
SerializeCompilationResult(compilation_result);
*aot_request.mutable_compilation_result() = compilation_result_proto;
VLOG(1) << "TpuAotCompilationRequest:\n" << aot_request.DebugString();
return aot_request;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
#include <string>
#include <vector>
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
namespace tensorflow {
namespace tpu {
namespace se = ::stream_executor;
// Describes the position of an argument or return value after the computation
// has been partitioned into cores.
struct ShardingAndIndex {
// Sharding across cores.
::xla::OpSharding sharding;
// Argument/return value number. If sharding is single-core, `indices` has a
// single element; otherwise, it has num_cores elements.
std::vector<int> indices;
};
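// Illustrative example (values are assumptions, not part of this header): for
// an argument replicated across a 2-core replica, `sharding` carries
// OpSharding::REPLICATED and `indices` holds one entry per core, e.g. {3, 3}
// if the value is argument number 3 on both cores. For a single-core
// (MAXIMAL) sharding, `indices` holds a single element.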
// TODO(b/158279168): Dedup with internal version.
// Return the per-device shape for a `shape` with a given `sharding`.
xla::Shape GetPerDeviceShape(const xla::Shape& shape,
const xla::HloSharding& sharding,
int64 device);
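// Illustrative example (assumed shapes, for documentation only): given an
// f32[8,16] array tiled 2-ways along dimension 0 across devices {0, 1}, the
// per-device shape returned for either device is f32[4,16], i.e. the tile
// limit minus the tile offset in each dimension. A tile-maximal sharding
// returns the original shape unchanged.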
stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
const xla::ProgramShape& program_shape,
absl::Span<const xla::Shape> argument_shapes,
absl::optional<const xla::Shape> result_layout,
absl::optional<const xla::DeviceAssignment> device_assignment,
int replica_count, int num_partitions,
const xla::DebugOptions* debug_options, const int* seed,
const int* launch_id, const bool* alias_passthrough_params,
const xla::FusionConfigCollection* fusion_config_collection,
const std::vector<std::vector<bool>>* fusion_config);
stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
const xla::ProgramShape& program_shape,
absl::Span<const xla::Shape> argument_shapes,
absl::optional<const xla::Shape> result_layout,
absl::optional<const xla::DeviceAssignment> device_assignment,
int replica_count,
int num_partitions, const xla::DebugOptions* debug_options);
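// Illustrative sketch of the convenience overload above (argument values are
// assumptions for the example only):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<xla::HloModuleConfig> config,
//       CreateModuleConfig(program_shape, argument_shapes,
//                          /*result_layout=*/absl::nullopt,
//                          /*device_assignment=*/absl::nullopt,
//                          /*replica_count=*/1, /*num_partitions=*/1,
//                          /*debug_options=*/nullptr));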
xla::ShapeTree<xla::HloSharding> GetSubtree(
const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
int element_index);
Status AddVariableUpdatesToCores(
const TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,
std::vector<bool>* may_modify_variables,
std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);
se::port::Status ComputeOutputShapesForEachCore(
const tpu::TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
std::vector<std::vector<xla::Shape>>* per_core_output_shapes);
se::port::Status CreateHloModules(
const TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
const absl::optional<xla::DeviceAssignment>& device_assignment,
std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);
se::port::StatusOr<TpuAotCompilationRequestProto>
CreateTpuAotCompilationRequest(
const xla::HloModuleGroup& module_group,
const XlaCompiler::CompilationResult& compilation_result,
const TPUCompileMetadataProto& metadata,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
const std::vector<std::vector<std::pair<int, bool>>>&
per_core_variable_indices,
const absl::optional<xla::DeviceAssignment>& device_assignment);
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_

View File

@ -0,0 +1,298 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
#include <cstdint>
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
namespace {
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
tpu::TpuMeshStateInterface** state) {
if (!rmgr->Lookup(rmgr->default_container(),
tpu::kTpuMeshCommonStateResourceName, state)
.ok()) {
return errors::FailedPrecondition(
"The TPU system has not been initialized.");
}
return Status::OK();
}
// Attempt to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded, or if the resource was not found. Else
// return the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
const char* resource_name) {
VLOG(1) << "Removing resource " << resource_name << " if it exists";
Status status = resource_manager->Delete<ResourceT>(
resource_manager->default_container(), resource_name);
if (status.ok()) {
VLOG(1) << "Removed existing resource " << resource_name;
return Status::OK();
}
if (status.code() == error::NOT_FOUND) {
VLOG(1) << "No resource " << resource_name << " to remove";
return Status::OK();
}
VLOG(1) << "Error removing resource " << resource_name << " : " << status;
return status;
}
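// Illustrative use of the helper above, mirroring the call sites later in
// this file: clear any stale mesh state before (re)configuring the system.
//
//   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
//                           GetTPUConfigResourceMgr(),
//                           tpu::kTpuMeshCommonStateResourceName));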
} // namespace
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ConfigureDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");
std::vector<int32_t> num_devices_per_host;
int chips_per_host = -1;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& input_tensor = ctx->input(i);
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
errors::InvalidArgument("Input ", i, " should be a scalar but has ",
input_tensor.dims(), " dimensions"));
if (chips_per_host == -1) {
chips_per_host = input_tensor.scalar<int32_t>()();
} else {
OP_REQUIRES(
ctx, chips_per_host == input_tensor.scalar<int32>()(),
errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
" TPU chips but host 0 has ", chips_per_host));
}
num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
}
TF_Status* status = TF_NewStatus();
size_t host_config_output_size;
char* host_config_output;
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
rmgr, tpu::kTpuMeshCommonStateResourceName));
ConfigureDistributedTpuOp_DoWork(
num_devices_per_host.size(), num_devices_per_host.data(),
&host_config_output_size, &host_config_output, status);
OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
tpu::kTpuMeshCommonStateResourceName,
tpu::TpuMeshStateInterface::Create()));
Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(host_config_output, host_config_output_size);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
TpuConfigurationApi_FreeCharArray(host_config_output);
VLOG(1) << "ConfigureDistributedTpuOp done";
}
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "WaitForDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");
size_t num_devices_per_host = -1;
size_t num_hosts = ctx->num_inputs();
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
OP_REQUIRES(
ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
errors::InvalidArgument("Input ", i, " should be a vector but has ",
host_ordinal_to_global_device_id_tensor.dims(),
" dimensions"));
}
std::vector<std::vector<int32_t>> mapping;
std::vector<int32_t*> mapping_arg;
mapping.resize(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
const auto host_ordinal_to_global_device_id =
host_ordinal_to_global_device_id_tensor.flat<int>();
if (num_devices_per_host == -1) {
num_devices_per_host =
host_ordinal_to_global_device_id_tensor.dim_size(0);
} else {
OP_REQUIRES(ctx,
num_devices_per_host ==
host_ordinal_to_global_device_id_tensor.dim_size(0),
errors::Internal(
"Host ", i, " has ",
host_ordinal_to_global_device_id_tensor.dim_size(0),
" TPU devices but host 0 has ", num_devices_per_host));
}
for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
++j) {
int32_t global_device_id = host_ordinal_to_global_device_id(j);
mapping[i].push_back(global_device_id);
}
mapping_arg.push_back(mapping[i].data());
}
TF_Status* status = TF_NewStatus();
size_t tpu_topology_output_size;
char* tpu_topology_output;
tpu::TpuMeshStateInterface* mesh_state;
auto* rmgr = GetTPUConfigResourceMgr();
OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state);
WaitForDistributedTpuOp_DoWork(
num_hosts, num_devices_per_host,
const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
&tpu_topology_output_size, &tpu_topology_output, status);
Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<tstring>()() =
std::string(tpu_topology_output, tpu_topology_output_size);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
TpuConfigurationApi_FreeCharArray(tpu_topology_output);
VLOG(1) << "WaitForDistributedTpuOp done";
}
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "ShutdownDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
TF_Status* status = TF_NewStatus();
OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
GetTPUConfigResourceMgr(),
tpu::kTpuMeshCommonStateResourceName));
ShutdownDistributedTpuOp_DoWork(status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
VLOG(1) << "ShutdownDistributedTpuOp done";
}
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "InitializeHostForDistributedTpuOp";
XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");
auto tpu_host_config = ctx->input(0).scalar<tstring>()();
size_t device_id_output_size;
int32_t* device_id_output;
TF_Status* status = TF_NewStatus();
InitializeHostForDistributedTpuOp_DoWork(
tpu_host_config.size(), tpu_host_config.data(),
enable_whole_mesh_compilations_, &device_id_output_size,
&device_id_output, status);
Tensor* ctx_output;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(
0, TensorShape({static_cast<long long>(device_id_output_size)}),
&ctx_output));
for (size_t i = 0; i < device_id_output_size; ++i) {
ctx_output->flat<int32>()(i) = device_id_output[i];
}
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
TpuConfigurationApi_FreeInt32Array(device_id_output);
VLOG(1) << "InitializeHostForDistributedTpuOp done";
}
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "SetGlobalTPUArrayOp";
XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");
auto tpu_topology = ctx->input(0).scalar<tstring>()();
TF_Status* status = TF_NewStatus();
SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
VLOG(1) << "SetGlobalTPUArrayOp done";
}
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "DisconnectDistributedTpuChipsOp";
XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");
TF_Status* status = TF_NewStatus();
int32_t number_of_chips_output = 0;
DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);
Tensor* ctx_output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
ctx_output->scalar<int32_t>()() = number_of_chips_output;
OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
TF_DeleteStatus(status);
VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}
// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
.Device(DEVICE_TPU_SYSTEM)
.HostMemory("output"),
ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
.Device(DEVICE_TPU_SYSTEM)
.HostMemory("inputs")
.HostMemory("topology"),
WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
.Device(DEVICE_TPU_SYSTEM)
.HostMemory("input")
.HostMemory("tpu_ids"),
InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
.Device(DEVICE_TPU_SYSTEM)
.HostMemory("number_of_tpu_chips"),
DisconnectDistributedTpuChipsOp);
} // namespace tensorflow

View File

@ -0,0 +1,156 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
// CompilationCacheServer will remain live until the device's Resource Manager
// is cleared or a ShutdownDistributedTpuOp is run on the same device.
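//
// A rough bring-up sketch (one plausible ordering, consistent with the
// comments in this header and the inputs/outputs of the registered kernels;
// the actual orchestration is produced by the graph rewrite passes):
//
//   _ConfigureDistributedTPU                  (once, on the master TPU_SYSTEM)
//   _InitializeHostForDistributedTPU          (once per host)
//   _WaitForDistributedTPU                    (once, on the master TPU_SYSTEM)
//   _SetGlobalTPUArray                        (once per host, with topology)
//   ... TPU compilation and execution ...
//   _DisconnectHostFromDistributedTPUSystem   (once per host)
//   _ShutdownDistributedTPU                   (once, on the master TPU_SYSTEM)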
class ConfigureDistributedTpuOp : public OpKernel {
public:
explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES(
ctx, ctx->num_inputs() > 0,
errors::Internal("_ConfigureDistributedTPU needs at least one input"));
}
void Compute(OpKernelContext* ctx) override;
~ConfigureDistributedTpuOp() override {}
private:
// ConfigureDistributedTpuOp is neither copyable nor movable.
ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
delete;
};
// The WaitForDistributedTpu op is used to block execution until
// the distributed TPU system has started up. It must be run on
// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
// on, after all of the InitializeHostForDistributedTpuOp Ops have
// completed.
class WaitForDistributedTpuOp : public OpKernel {
public:
explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx,
ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
errors::InvalidArgument("startup_timeout_sec ",
startup_timeout_sec_, " must be >0"));
}
void Compute(OpKernelContext* ctx) override;
~WaitForDistributedTpuOp() override {}
private:
// The time to wait for all hosts to start up.
int startup_timeout_sec_;
// WaitForDistributedTpuOp is neither copyable nor movable.
WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
};
// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
// TensorFlow. It should be run on the TPU_SYSTEM device where
// ConfigureDistributedTpuOp was run.
class ShutdownDistributedTpuOp : public OpKernel {
public:
explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override;
~ShutdownDistributedTpuOp() override {}
private:
// ShutdownDistributedTpuOp is neither copyable nor movable.
ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
};
// The InitializeHostForDistributedTpu op is used to initialize the
// TPUPlatform on a host in a distributed TPU system. It should be
// run on every host containing TPU devices before any other Ops that use
// TPU are run.
class InitializeHostForDistributedTpuOp : public OpKernel {
public:
explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
ctx->GetAttr("enable_whole_mesh_compilations",
&enable_whole_mesh_compilations_)
.IgnoreError();
}
void Compute(OpKernelContext* ctx) override;
~InitializeHostForDistributedTpuOp() override {}
private:
// InitializeHostForDistributedTpuOp is neither copyable nor movable.
InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
delete;
InitializeHostForDistributedTpuOp& operator=(
const InitializeHostForDistributedTpuOp&) = delete;
bool enable_whole_mesh_compilations_ = false;
};
// The SetGlobalTPUArray op is used to set the global TPU topology on a
// host in a distributed TPU system. It should be run on every host
// containing TPU devices before any other Ops that use TPU are run.
class SetGlobalTPUArrayOp : public OpKernel {
public:
explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override;
~SetGlobalTPUArrayOp() override {}
private:
// SetGlobalTPUArrayOp is neither copyable nor movable.
SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
};
// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
// host from a running TPUDriver instance. It should be run on every host
// containing TPU devices before the ShutdownDistributedTpuOp is run on
// the TPU_SYSTEM.
class DisconnectDistributedTpuChipsOp : public OpKernel {
public:
explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override;
~DisconnectDistributedTpuChipsOp() override {}
private:
// DisconnectDistributedTpuChipsOp is neither copyable nor movable.
DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
delete;
DisconnectDistributedTpuChipsOp& operator=(
const DisconnectDistributedTpuChipsOp&) = delete;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

View File

@ -0,0 +1,94 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor_shape.proto";
// A serialization of TPUExecutable. Only includes fields necessary to load
// and execute a program on a worker node.
message TPUExecutableInfoProto {
reserved 1;
// The shapes of the inputs and outputs.
repeated xla.ShapeProto input_shapes = 2;
reserved 7; // was input_shape
xla.ShapeProto output_shape = 3;
message UpdateIndexPair {
int32 index = 1;
bool updated = 2;
}
message ShapeIndex {
repeated int32 index = 1;
}
// Dynamic output indices indicate which outputs have dynamic dimensions.
repeated ShapeIndex dynamic_output_indices = 11;
// For each resource variable output, what was the index of the corresponding
// input and was it updated? The indices are sorted by input order.
repeated UpdateIndexPair variable_indices = 10;
// The shapes of the outputs when represented as Tensors. These may not
// match the output_shape values because we may flatten tensors to avoid
// excess padding.
repeated TensorShapeProto output_tensor_shapes = 8;
reserved 4;
// Optional session module for passing XLA computations between TPUCompileOp
// and TPUExecuteOp. This is needed to support the
// --xla_dump_hlo_snapshots flag.
xla.HloSnapshot session_module = 5;
// The physical device ids assigned to the replicated cores.
xla.DeviceAssignmentProto device_assignment = 6;
}
// Metadata for a data transfer between device and host.
message TPUHostTransferProto {
enum TransferDirection {
NONE = 0;
DEVICE_TO_HOST = 1;
HOST_TO_DEVICE = 2;
}
// Channel identifier assigned by compiler and used in host commands.
int64 channel = 1;
// Direction of the transfer operation.
TransferDirection direction = 2;
// Channel identifier provided by the XLA client.
string key = 3;
// Depth of nested loops for this transfer operation.
int64 nested_while_level = 4;
// Shape of the data to be transferred (including layout).
xla.ShapeProto shape = 5;
// Address of the device buffer in HBM (byte offset).
int64 buffer_offset = 6;
// Original data type for this host transfer before X64 rewrite.
xla.PrimitiveType original_type = 7;
// If this host transfer is a split X64 transfer, specifies whether this
// transfer is for the lower bits.
bool is_lower_bits = 8;
}
message TPUHostTransferInfoProto {
repeated TPUHostTransferProto host_transfers = 1;
}

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
typedef struct XLA_TpuMeshState XLA_TpuMeshState;
// Creates a new TPU mesh state object.
XLA_TpuMeshState* TpuMeshState_Create();
// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
// Returns a pointer to an opaque mesh data structure used internally.
void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);
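// Illustrative lifetime sketch for this C API (caller-side, assuming the
// symbols are provided by the linked TPU runtime):
//
//   XLA_TpuMeshState* mesh_state = TpuMeshState_Create();
//   void* common_state = TpuMeshState_MeshCommonState(mesh_state);
//   // ... use common_state via the internal mesh interfaces ...
//   TpuMeshState_Free(mesh_state);  // mesh_state must not be used afterwards.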
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_
#include <string>
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
namespace tensorflow {
class TpuMeshCommonState;
namespace tpu {
const char kTpuMeshCommonStateResourceName[] = "tpu_mesh_common_state";
class TpuMeshStateInterface : public tensorflow::ResourceBase {
public:
explicit TpuMeshStateInterface(XLA_TpuMeshState* handle)
: mesh_state_(handle) {
}
~TpuMeshStateInterface() override {
if (mesh_state_ != nullptr) {
TpuMeshState_Free(mesh_state_);
}
}
static TpuMeshStateInterface* Create() {
return new TpuMeshStateInterface(TpuMeshState_Create());
}
const XLA_TpuMeshState* data() const { return mesh_state_; }
tensorflow::TpuMeshCommonState* mesh_common_state() const {
return static_cast<tensorflow::TpuMeshCommonState*>(
TpuMeshState_MeshCommonState(mesh_state_));
}
// Returns whether we should include the device assignment as a static field
// to the TPU program. This also determines whether we should include the
// device assignment as part of the compilation cache key.
bool NeedsStaticDeviceAssignment(
const TPUCompileMetadataProto& metadata,
TpuCoreTypeEnum tpu_core_type) const {
// Static device assignment enables XLA to perform certain optimizations when
// all cores are used in the replicated computation.
return metadata.num_cores_per_replica() * metadata.num_replicas() ==
TpuTopology_AvailableCoreCount(mesh_state_,
tpu_core_type);
}
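// Worked example for the check above (numbers are illustrative): with
// num_cores_per_replica() == 2 and num_replicas() == 4, a static device
// assignment is requested only when the mesh reports exactly 8 available
// cores of `tpu_core_type`; a smaller replicated computation compiles
// without one.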
string DebugString() const override { return "TpuMeshStateInterface"; }
private:
XLA_TpuMeshState* mesh_state_;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_

View File

@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_program.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
namespace tensorflow {
namespace tpu {
namespace {
namespace se_tpu = ::stream_executor::tpu;
using stream_executor::port::StatusOr;
using xla::Shape;
StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
std::unique_ptr<xla::HloModuleGroup> module_group,
const XlaCompiler::CompilationResult& compilation_result,
const TPUCompileMetadataProto& metadata,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
const std::vector<std::vector<std::pair<int, bool>>>&
per_core_variable_indices,
const absl::optional<xla::DeviceAssignment>& device_assignment) {
VLOG(1) << "Run CompileAheadOfTime.";
TF_ASSIGN_OR_RETURN(TpuAotCompilationRequestProto aot_request,
CreateTpuAotCompilationRequest(
*module_group, compilation_result, metadata,
per_core_arg_shapes, per_core_output_shapes,
per_core_variable_indices, device_assignment));
se_tpu::SerializedProto serialized_aot_request =
se_tpu::SerializeProto(aot_request);
auto cleanup = gtl::MakeCleanup([serialized_aot_request] {
se_tpu::SerializedProto_Free(serialized_aot_request);
});
XLA_TpuProgram** xla_tpu_programs = nullptr;
size_t count = 0;
StatusHelper status;
VLOG(1) << "Run TpuCompile_CompileAheadOfTime.";
TpuCompile_CompileAheadOfTime(serialized_aot_request, &xla_tpu_programs,
&count, status.c_status);
VLOG(1) << "Run CompileAheadOfTime completed.";
if (!status.status().ok()) {
return status.status();
}
std::vector<XLA_TpuProgram*> tpu_programs(count, nullptr);
for (size_t i = 0; i < count; ++i) {
tpu_programs[i] = xla_tpu_programs[i];
}
delete[] xla_tpu_programs;
return tpu_programs;
}
StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
const TPUCompileMetadataProto& metadata,
const XlaCompiler::CompilationResult& compilation_result,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
const std::vector<std::vector<std::pair<int, bool>>>&
per_core_variable_indices,
const absl::optional<xla::DeviceAssignment>& device_assignment) {
VLOG(1) << "Compile Tpu programs.";
std::vector<std::unique_ptr<xla::HloModule>> hlo_modules;
auto status = CreateHloModules(metadata, compilation_result,
device_assignment, &hlo_modules);
if (!status.ok()) {
return status;
}
return CompileAheadOfTime(
absl::make_unique<xla::HloModuleGroup>(hlo_modules[0]->name(),
absl::MakeSpan(hlo_modules)),
compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes,
per_core_variable_indices, device_assignment);
}
} // namespace
int64_t TpuProgram::program_size() const {
int64_t total_size = 0;
for (XLA_TpuProgram* tpu_program : tpu_programs_) {
total_size += TpuProgram_GetProgramSize(tpu_program);
}
return total_size;
}
bool TpuProgram::LogProgramMemorySummary() {
bool success = true;
for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
success &= TpuProgram_LogProgramMemorySummary(tpu_program);
}
return success;
}
void TpuProgram::UnloadAndDestroyPrograms() {
for (XLA_TpuProgram* tpu_program : tpu_programs_) {
StatusHelper status;
TpuProgram_UnloadAndDestroy(tpu_program, status.c_status);
auto s = status.status();
if (!s.ok()) {
LOG(ERROR) << "TpuProgram::UnloadPrograms(): " << s.ToString();
}
}
tpu_programs_.clear();
}
/*static*/ Status TpuProgram::Build(
const TPUCompileMetadataProto& metadata,
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgram* tpu_program) {
std::vector<std::vector<xla::Shape>> per_core_output_shapes(
metadata.num_cores_per_replica());
TF_RETURN_IF_ERROR(ComputeOutputShapesForEachCore(
metadata, compilation_result, &per_core_output_shapes));
std::vector<std::vector<std::pair<int, bool>>> per_core_variable_indices(
metadata.num_cores_per_replica());
std::vector<bool> may_modify_variables;
TF_RETURN_IF_ERROR(AddVariableUpdatesToCores(
metadata, compilation_result, arg_core_mapping, &may_modify_variables,
&per_core_output_shapes, &per_core_variable_indices));
TF_RET_CHECK(per_core_arg_shapes.size() == metadata.num_cores_per_replica());
TF_RET_CHECK(per_core_output_shapes.size() == per_core_arg_shapes.size());
TF_RET_CHECK(per_core_output_shapes.size() ==
per_core_variable_indices.size());
tpu_program->set_may_modify_variables(may_modify_variables);
// With shardable input/output pairs, XLA could generate separate
// sharding/unsharding programs along with the main program. The
// sharding/unsharding programs will be in nested entries of the AOT
// compilation result.
auto status_or = CompileAheadOfTime(
metadata, compilation_result, per_core_arg_shapes, per_core_output_shapes,
per_core_variable_indices, xla_device_assignment);
TF_ASSIGN_OR_RETURN(std::vector<XLA_TpuProgram*> xla_tpu_programs,
std::move(status_or));
// SPMD could return 1 result for all partitions.
TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
xla_tpu_programs.size() == metadata.num_cores_per_replica());
tpu_program->set_tpu_programs(xla_tpu_programs);
// TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
TpuSerializedProto serialized_executable_info;
TpuProgram_GetExecutableInfo(xla_tpu_programs[0],
&serialized_executable_info);
TPUExecutableInfoProto executable_info =
se_tpu::DeserializeProto<TPUExecutableInfoProto>(
serialized_executable_info);
tpu_program->set_executable_info(executable_info);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
TPUHostTransferInfoProto host_transfer_info;
TpuSerializedProto serialized_host_transfer_info;
TpuProgram_GetHostTransferInfo(xla_tpu_programs[0],
&serialized_host_transfer_info);
if (serialized_host_transfer_info.size > 0) {
host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
serialized_host_transfer_info);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
}
tpu_program->set_host_transfer_info(host_transfer_info);
TpuSerializedProto serialized_hlo_metadata;
TpuProgram_GetHloMetadata(xla_tpu_programs[0], &serialized_hlo_metadata);
xla::HloProto hlo_metadata =
se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
tpu_program->set_hlo_metadata(hlo_metadata);
StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
namespace tpu {
class TpuAotCompilationOptions : public xla::AotCompilationOptions {
public:
explicit TpuAotCompilationOptions(int64 replica_count)
: num_cores_(0), replica_count_(replica_count) {}
// Returns the ID of the platform to which these options apply.
se::Platform::Id PlatformId() const override {
LOG(FATAL) << "Not implemented.";
return nullptr;
};
void set_num_cores(int64 tpu_cores) { num_cores_ = tpu_cores; }
int64 replica_count() const override { return replica_count_; }
int64 num_cores() const override { return num_cores_; }
void set_allow_separate_sharding_programs(bool allow) {
allow_separate_sharding_programs_ = allow;
}
bool allow_separate_sharding_programs() const {
return allow_separate_sharding_programs_;
}
const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
shardable_value_update_pairs() const {
return shardable_value_update_pairs_;
}
void set_shardable_value_update_pairs(
std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
shardable_value_update_pairs_ = std::move(pairs);
}
private:
int64 num_cores_;
int64 replica_count_;
// Whether to allow the compiler to create separate sharding and unsharding
// programs, and modify the original program's input/output sharded size. This
// is used for XLA-chosen sharding on parameters without an on-device loop:
// the caller can invoke sharding first, then (repeatedly) invoke the sharded
// main program, and finally invoke the unsharding program when it needs the
// full output.
bool allow_separate_sharding_programs_ = false;
// The list of input/output pairs in the main program that could be sharded.
std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
shardable_value_update_pairs_;
};
// An executable capable of being fed to a TPU device.
class TpuProgram {
public:
using Status = ::stream_executor::port::Status;
virtual ~TpuProgram() = default;
static Status Build(
const TPUCompileMetadataProto& metadata,
const tensorflow::XlaCompiler::CompilationResult& compilation_result,
const std::vector<ShardingAndIndex>& arg_core_mapping,
const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
TpuProgram* tpu_program);
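// Illustrative call sequence (a sketch only; the surrounding compile op is
// responsible for producing `metadata`, `compilation_result`, and
// `arg_core_mapping`):
//
//   TpuProgram tpu_program;
//   TF_RETURN_IF_ERROR(TpuProgram::Build(metadata, compilation_result,
//                                        arg_core_mapping,
//                                        per_core_arg_shapes,
//                                        xla_device_assignment,
//                                        &tpu_program));
//   VLOG(1) << "Compiled " << tpu_program.program_count() << " program(s), "
//           << tpu_program.program_size() << " bytes total.";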
size_t program_count() const {
return tpu_programs_.size();
}
int64_t program_size() const;
bool LogProgramMemorySummary();
void UnloadAndDestroyPrograms();
const std::vector<bool>& may_modify_variables() const {
return may_modify_variables_;
}
void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
may_modify_variables_ = may_modify_variables;
}
const tf2xla::HostComputeMetadata& host_compute_metadata() const {
return host_compute_metadata_;
}
void set_host_compute_metadata(
const tf2xla::HostComputeMetadata& host_compute_metadata) {
host_compute_metadata_ = host_compute_metadata;
}
const std::vector<XLA_TpuProgram*>& tpu_programs() const {
return tpu_programs_;
}
void set_tpu_programs(std::vector<XLA_TpuProgram*> tpu_programs) {
tpu_programs_ = tpu_programs;
}
const TPUExecutableInfoProto& executable_info() const {
return executable_info_;
}
void set_executable_info(const TPUExecutableInfoProto& executable_info) {
executable_info_ = executable_info;
}
const TPUHostTransferInfoProto& host_transfer_info() const {
return host_transfer_info_;
}
void set_host_transfer_info(
const TPUHostTransferInfoProto& host_transfer_info) {
host_transfer_info_ = host_transfer_info;
}
const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
hlo_metadata_ = hlo_metadata;
}
private:
std::vector<bool> may_modify_variables_;
tf2xla::HostComputeMetadata host_compute_metadata_;
std::vector<XLA_TpuProgram*> tpu_programs_; // Not owned.
TPUExecutableInfoProto executable_info_;
TPUHostTransferInfoProto host_transfer_info_;
xla::HloProto hlo_metadata_;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_H_

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/random.h"
namespace tensorflow {
namespace tpu {
std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
return session_metadata ? session_metadata->name() : "";
}
std::string ProtoKeyForComputation(const std::string& key, int core) {
return absl::StrCat(key, ":", core);
}
xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
const std::string& key) {
const std::vector<std::string> splits = absl::StrSplit(key, '|');
if (splits.size() == 1) {
// No guaranteed_const.
return TpuCompilationCacheKey(key);
} else if (splits.size() != 3) {
return errors::InvalidArgument("Invalid TPU compilation cache key:", key);
}
TpuCompilationCacheKey parsed_key(splits.at(0));
parsed_key.has_guaranteed_const = true;
parsed_key.session_handle = splits.at(1);
const string fingerprint = splits.at(2);
parsed_key.guaranteed_const_fingerprint = [fingerprint] {
return fingerprint;
};
return parsed_key;
}
xla::CompileOnlyClient::AotXlaComputationInstance
BuildAotXlaComputationInstance(
const XlaCompiler::CompilationResult& compilation_result) {
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = compilation_result.computation.get();
for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
instance.argument_layouts.push_back(&shape);
}
instance.result_layout = &compilation_result.xla_output_shape;
return instance;
}
Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
if (tensor.dtype() != DT_INT64 ||
!TensorShapeUtils::IsVector(tensor.shape())) {
return errors::InvalidArgument("Shape tensor must be an int64 vector.");
}
const int64 rank = tensor.NumElements();
auto tensor_dims = tensor.flat<int64>();
std::vector<int64> dims(rank);
for (int64 i = 0; i < rank; ++i) {
dims[i] = tensor_dims(i);
}
return TensorShapeUtils::MakeShape(dims, shape);
}
Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
std::vector<TensorShape>* shapes) {
shapes->resize(dynamic_shapes.size());
for (int i = 0; i < dynamic_shapes.size(); ++i) {
TF_RETURN_IF_ERROR(
ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
}
return Status::OK();
}
Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
std::vector<TensorShape>* shapes) {
shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
size_t i = 0;
for (auto& dynamic_shape : dynamic_shapes) {
TF_RETURN_IF_ERROR(
ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
++i;
}
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
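
A minimal sketch (not part of this commit) of exercising ParseCompilationCacheKey above; the key strings are purely illustrative, and only fields visibly set in this file (session_handle, has_guaranteed_const) are read back:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"

void ParseKeyExample() {
  // A key without '|' separators carries no guaranteed_const information.
  auto plain = tensorflow::tpu::ParseCompilationCacheKey("program_fp");
  CHECK(plain.ok());
  CHECK(!plain.ValueOrDie().has_guaranteed_const);

  // A "<key>|<session_handle>|<fingerprint>" key does; any other arity is an
  // InvalidArgument error.
  auto full = tensorflow::tpu::ParseCompilationCacheKey(
      "program_fp|session_0|const_fp");
  CHECK(full.ok());
  CHECK(full.ValueOrDie().has_guaranteed_const);
  CHECK_EQ(full.ValueOrDie().session_handle, "session_0");
}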

View File

@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
namespace tensorflow {
namespace tpu {
// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may
// be null.
std::string SessionNameFromMetadata(const SessionMetadata* session_metadata);
// Generates cache proto key for a given computation on a TPU core.
std::string ProtoKeyForComputation(const std::string& key, int core);
// Returns a TpuCompilationCacheKey parsed from given key or an error.
xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
const std::string& key);
xla::CompileOnlyClient::AotXlaComputationInstance
BuildAotXlaComputationInstance(
const XlaCompiler::CompilationResult& compilation_result);
// Returns true if TPU compilation is enabled.
bool IsTpuCompilationEnabled();
// Converts an int64 host memory `tensor` to a `shape`.
Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape);
Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
std::vector<TensorShape>* shapes);
Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
std::vector<TensorShape>* shapes);
// Given a tensor of `shape` and `type`, what shape should it be stored in on
// the TPU device? This function transposes or flattens excessively-padded
// tensors to rank 1, but leaves other tensor shapes alone.
xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
DataType type,
bool use_fast_memory);
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
#ifdef PLATFORM_GOOGLE
#include "base/tracer.h"
#else
#undef TRACESTRING
#define TRACESTRING(x)
#undef TRACELITERAL
#define TRACELITERAL(x)
#endif
#endif // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
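
A brief illustrative sketch (not part of this commit): because the macros above expand to nothing outside PLATFORM_GOOGLE, call sites can annotate hot paths unconditionally. The function and trace strings below are hypothetical:

#include <string>
#include "tensorflow/core/tpu/kernels/trace_util.h"

void TracedCompileStep(const std::string& key) {
  TRACESTRING(key);              // Real trace annotation only under PLATFORM_GOOGLE.
  TRACELITERAL("compile step");  // No-op in the open-source build.
  // ... actual work would go here ...
}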

View File

@ -64,13 +64,20 @@ TfTpu_ConfigApiFn* ConfigApiFn() {
}
Status InitializeTpuLibrary(void* library_handle) {
bool shared_object_loaded = true;
if (library_handle == nullptr) {
library_handle = dlopen(nullptr, RTLD_LAZY);
shared_object_loaded = false;
}
TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));
if (shared_object_loaded) {
// Initialize TPU platform when the platform code is loaded from a library.
InitializeApiFn()->TfTpu_InitializeFn();
}
return Status::OK();
}
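
A hedged sketch (not part of this commit) of a caller for the updated InitializeTpuLibrary: the caller is assumed to live in the same namespace as the function, and "libtpu.so" is a purely illustrative library name:

#include <dlfcn.h>

Status LoadTpuLibraryExample() {
  // A successfully dlopen'ed handle takes the shared-object path above,
  // including the explicit TfTpu_InitializeFn() call; passing nullptr falls
  // back to symbols already linked into the current process.
  void* handle = dlopen("libtpu.so", RTLD_NOW);
  return InitializeTpuLibrary(handle);
}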

View File

@ -0,0 +1,234 @@
# Description: StreamExecutor Interface for TPUs
package(
default_visibility = ["//tensorflow/core/tpu:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tpu_executor_c_api_hdrs",
hdrs = ["tpu_executor_c_api.h"],
deps = [
"//tensorflow/c:tf_attrtype",
"//tensorflow/c:tf_datatype",
"//tensorflow/c:tf_status",
],
)
cc_library(
name = "tpu_node_context_c_api_hdrs",
hdrs = ["tpu_node_context_c_api.h"],
deps = [
":tpu_executor_c_api_hdrs",
],
)
cc_library(
name = "status_helper",
hdrs = ["status_helper.h"],
deps = [
":tpu_executor_c_api_hdrs",
"//tensorflow/core/platform:status",
],
)
cc_library(
name = "c_api_conversions",
hdrs = ["c_api_conversions.h"],
deps = [
":tpu_executor_c_api_hdrs",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/stream_executor:stream",
"@com_google_absl//absl/container:inlined_vector",
],
)
cc_library(
name = "proto_helper",
srcs = ["proto_helper.cc"],
hdrs = ["proto_helper.h"],
deps = ["//tensorflow/core:lib"],
)
cc_library(
name = "tpu_stream",
hdrs = ["tpu_stream.h"],
deps = [
":tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
],
)
cc_library(
name = "tpu_timer",
hdrs = ["tpu_timer.h"],
deps = [
":tpu_executor_c_api_hdrs",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:stream",
],
)
cc_library(
name = "tpu_executor",
srcs = ["tpu_executor.cc"],
hdrs = ["tpu_executor.h"],
deps = [
":c_api_conversions",
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_executor_interface",
":tpu_platform",
":tpu_platform_interface",
":tpu_stream",
":tpu_timer",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
"//tensorflow/stream_executor:stream",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "tpu_executor_hdrs",
hdrs = ["tpu_executor.h"],
deps = [
":tpu_executor_c_api_hdrs",
":tpu_executor_interface",
":tpu_platform_hdrs",
":tpu_platform_interface",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:stream_header",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "tpu_platform_hdrs",
hdrs = ["tpu_platform.h"],
deps = [
":tpu_executor_c_api_hdrs",
":tpu_platform_interface",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:stream_header",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "tpu_node_context",
srcs = ["tpu_node_context.cc"],
hdrs = ["tpu_node_context.h"],
deps = [
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_node_context_c_api_hdrs",
":tpu_platform_interface",
":tpu_transfer_manager",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:framework",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
],
)
cc_library(
name = "tpu_platform",
srcs = ["tpu_platform.cc"],
hdrs = ["tpu_platform.h"],
deps = [
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_executor_hdrs",
":tpu_platform_interface",
"//tensorflow/c:tf_status",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:stream",
"@com_google_absl//absl/container:flat_hash_map",
],
alwayslink = True,
)
cc_library(
name = "tpu_transfer_manager",
srcs = ["tpu_transfer_manager_registration.cc"],
deps = [
":tpu_platform",
":tpu_transfer_manager_base",
"//tensorflow/compiler/xla/service:transfer_manager",
],
)
cc_library(
name = "tpu_transfer_manager_base",
srcs = ["tpu_transfer_manager.cc"],
hdrs = ["tpu_transfer_manager.h"],
deps = [
":c_api_conversions",
":proto_helper",
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_platform",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/stream_executor:stream",
],
)
cc_library(
name = "tpu_computation_placer",
srcs = ["tpu_computation_placer.cc"],
hdrs = ["tpu_computation_placer.h"],
deps = [
":tpu_executor_c_api_hdrs",
":tpu_platform",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:computation_placer",
],
alwayslink = True,
)
cc_library(
name = "tpu_platform_interface",
srcs = ["tpu_platform_interface.cc"],
hdrs = ["tpu_platform_interface.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor:multi_platform_manager",
"//tensorflow/stream_executor:stream_executor_headers",
],
)
cc_library(
name = "tpu_stream_interface",
hdrs = ["tpu_stream_interface.h"],
visibility = ["//visibility:public"],
deps = ["//tensorflow/stream_executor:stream_executor_internal"],
)
cc_library(
name = "tpu_executor_interface",
hdrs = ["tpu_executor_interface.h"],
visibility = ["//visibility:public"],
deps = [
":tpu_platform_interface",
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor:stream_executor_internal",
"//tensorflow/stream_executor:stream_header",
],
)

View File

@ -0,0 +1,115 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
class TpuConversions {
public:
static stream_executor::DeviceMemoryBase
SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {
stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);
base.SetPayload(se_base.payload);
return base;
}
static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(
const stream_executor::DeviceMemoryBase& base) {
SE_DeviceMemoryBase se_base;
se_base.opaque = const_cast<void*>(base.opaque());
se_base.payload = base.payload();
se_base.size = base.size();
return se_base;
}
static xla::Shape CShapeToXlaShape(XLA_Shape* shape) {
xla::ShapeProto p;
p.ParseFromArray(shape->bytes, shape->size);
return xla::Shape(p);
}
static void XlaShapeToCShape(const xla::Shape& xla_shape,
XLA_Shape* c_shape) {
xla::ShapeProto p = xla_shape.ToProto();
std::string p_str = p.SerializeAsString();
c_shape->bytes = new char[p_str.size()];
c_shape->size = p_str.size();
memcpy(c_shape->bytes, p_str.data(), p_str.size());
}
static void XLAShapedBufferToCShapedBuffer(
const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
XlaShapeToCShape(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
XlaShapeToCShape(buffer.on_device_shape(),
&c_device_buffer->on_device_shape);
c_device_buffer->device_ordinal = buffer.device_ordinal();
absl::InlinedVector<SE_DeviceMemoryBase, 2> bases;
for (auto& pair : buffer.buffers()) {
bases.push_back(DeviceMemoryBaseToSE_DeviceMemoryBase(pair.second));
}
c_device_buffer->count = bases.size();
c_device_buffer->bases = new SE_DeviceMemoryBase[bases.size()];
for (int i = 0; i < bases.size(); ++i) {
c_device_buffer->bases[i] = bases[i];
}
}
static void XLALiteralToCLiteral(const xla::LiteralSlice& literal,
XLA_Literal* c_literal) {
XlaShapeToCShape(literal.shape(), &c_literal->shape);
auto shapes = xla::ShapeUtil::GetLeafShapes(literal.shape());
c_literal->buffers = new char*[shapes.size()];
c_literal->sizes = new size_t[shapes.size()];
c_literal->count = shapes.size();
for (int i = 0; i < shapes.size(); ++i) {
c_literal->buffers[i] = reinterpret_cast<char*>(
const_cast<void*>(literal.untyped_data(shapes[i].index)));
c_literal->sizes[i] = literal.size_bytes(shapes[i].index);
}
}
static xla::MutableBorrowingLiteral CLiteralToXLALiteral(
XLA_Literal* c_literal) {
xla::Shape shape = CShapeToXlaShape(&c_literal->shape);
LOG(INFO) << "Shape: " << shape.DebugString();
return xla::MutableBorrowingLiteral(
absl::MakeSpan(c_literal->buffers, c_literal->count), shape);
}
static void CShapeCleanup(XLA_Shape* c_shape) { delete[] c_shape->bytes; }
static void CLiteralCleanup(XLA_Literal* c_literal) {
delete[] c_literal->buffers;
delete[] c_literal->sizes;
CShapeCleanup(&c_literal->shape);
}
static void CShapedBufferCleanup(XLA_ShapedBuffer* c_buffer) {
CShapeCleanup(&c_buffer->on_device_shape);
CShapeCleanup(&c_buffer->on_host_shape);
delete[] c_buffer->bases;
}
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
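
As a hedged sketch (not part of this commit), an xla::Shape can be round-tripped through the XLA_Shape C struct with the helpers above; the caller owns the serialized bytes and must call CShapeCleanup:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"

void ShapeRoundTripExample() {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});

  // Serialize into the C representation used across the TPU C API boundary.
  XLA_Shape c_shape;
  TpuConversions::XlaShapeToCShape(shape, &c_shape);

  // Reconstruct the xla::Shape on the other side of the boundary.
  xla::Shape round_tripped = TpuConversions::CShapeToXlaShape(&c_shape);
  CHECK(xla::ShapeUtil::Equal(shape, round_tripped));

  // XlaShapeToCShape allocates c_shape.bytes with new[]; release it here.
  TpuConversions::CShapeCleanup(&c_shape);
}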

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/proto_helper.h"
extern "C" {
void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto) {
CHECK_NE(proto, nullptr);
CHECK_NE(proto->bytes, nullptr);
CHECK_GT(proto->size, 0);
delete[] proto->bytes;
}
} // extern "C"

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
#include <cstddef>
#include "tensorflow/core/platform/logging.h"
extern "C" {
typedef struct TpuSerializedProto {
const char* bytes;
size_t size;
} TpuSerializedProto;
void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto);
} // extern "C"
namespace stream_executor {
namespace tpu {
using SerializedProto = TpuSerializedProto;
// Serializes a proto and puts the result in the given SerializedProto* argument.
//
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
template <class Proto>
inline void SerializeProto(const Proto& proto,
SerializedProto* serialized_proto) {
auto size = proto.ByteSizeLong();
auto bytes = new char[size];
CHECK(proto.SerializeToArray(bytes, size));
serialized_proto->size = size;
serialized_proto->bytes = bytes;
}
// Serializes a proto and returns the result as a SerializedProto value.
//
// Users should call SerializedProto_Free on the return value afterwards.
template <class Proto>
inline SerializedProto SerializeProto(const Proto& proto) {
SerializedProto serialized_proto;
SerializeProto(proto, &serialized_proto);
return serialized_proto;
}
// Deserializes a buffer and returns the corresponding proto. If the buffer is
// empty, returns an empty proto.
template <class Proto>
inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
Proto proto;
if (serialized_proto.bytes != nullptr) {
CHECK_GT(serialized_proto.size, 0);
CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
<< "Invalid buffer, failed to deserialize buffer.";
}
return proto;
}
// Releases the memory allocated for serialized protos.
inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
CHECK_NE(serialized_proto.bytes, nullptr);
CHECK_GT(serialized_proto.size, 0);
delete[] serialized_proto.bytes;
}
} // namespace tpu
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
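
An illustrative sketch (not part of this commit) of the intended round trip; xla::ShapeProto is just an assumed example payload:

#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

void ProtoRoundTripExample() {
  xla::ShapeProto shape_proto;
  shape_proto.set_element_type(xla::F32);

  // Serialize into a C-compatible buffer owned by the caller.
  stream_executor::tpu::SerializedProto serialized =
      stream_executor::tpu::SerializeProto(shape_proto);

  // Reconstruct the proto, e.g. after crossing the C API boundary.
  xla::ShapeProto copy =
      stream_executor::tpu::DeserializeProto<xla::ShapeProto>(serialized);
  (void)copy;

  // The serialized bytes must be released explicitly.
  stream_executor::tpu::SerializedProto_Free(serialized);
}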

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
struct StatusHelper {
StatusHelper() : c_status(TpuStatus_New()) {}
~StatusHelper() { TpuStatus_Free(c_status); }
bool ok() { return TpuStatus_Code(c_status) == 0; }
tensorflow::Status status() {
if (!ok()) {
return tensorflow::Status(
tensorflow::error::Code(TpuStatus_Code(c_status)),
TpuStatus_Message(c_status));
} else {
return tensorflow::Status::OK();
}
}
SE_Status* c_status;
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
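
A minimal sketch (not part of this commit) of the intended pattern: a stack-allocated StatusHelper hands its SE_Status to a C API call and converts the result back into a tensorflow::Status. The executor and stream are assumed to have been created elsewhere through the C API:

#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

tensorflow::Status BlockOnStreamExample(SE_StreamExecutor* executor,
                                        SE_Stream* stream) {
  StatusHelper status;
  TpuExecutor_BlockHostUntilDone(executor, stream, status.c_status);
  return status.status();  // The SE_Status is freed by ~StatusHelper.
}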

View File

@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
template <typename T>
using StatusOr = TpuComputationPlacer::StatusOr<T>;
TpuComputationPlacer::TpuComputationPlacer() {
placer_ = TpuComputationPlacer_New();
}
TpuComputationPlacer::~TpuComputationPlacer() {
TpuComputationPlacer_Free(placer_);
}
StatusOr<int> TpuComputationPlacer::DeviceId(int replica, int computation,
int replica_count,
int computation_count) {
LOG(FATAL) << "Unimplemented.";
}
StatusOr<xla::DeviceAssignment> TpuComputationPlacer::AssignDevices(
int replica_count, int computation_count) {
LOG(FATAL) << "Unimplemented.";
}
static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
return std::make_unique<TpuComputationPlacer>();
}
static bool InitModule() {
xla::ComputationPlacer::RegisterComputationPlacer(
tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
return true;
}
static bool module_initialized = InitModule();

View File

@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
class TpuComputationPlacer : public xla::ComputationPlacer {
public:
template <typename T>
using StatusOr = xla::StatusOr<T>;
TpuComputationPlacer();
~TpuComputationPlacer() override;
StatusOr<int> DeviceId(int replica, int computation, int replica_count,
int computation_count) override;
StatusOr<xla::DeviceAssignment> AssignDevices(int replica_count,
int computation_count) override;
private:
XLA_ComputationPlacer* placer_;
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_

View File

@ -0,0 +1,355 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/stream_executor/tpu/tpu_timer.h"
using stream_executor::DeviceMemoryBase;
namespace tensorflow {
namespace {
using ::stream_executor::port::Status;
} // namespace
TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }
Status TpuExecutor::Init(int device_ordinal,
::stream_executor::DeviceOptions device_options) {
StatusHelper status;
SE_DeviceOptions* options =
TpuExecutor_NewDeviceOptions(device_options.flags());
TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
TpuExecutor_FreeDeviceOptions(options);
return status.status();
}
int TpuExecutor::PlatformDeviceCount() {
return TpuExecutor_PlatformDeviceCount(executor_);
}
void TpuExecutor::SyncAndForgetFailedStreams() {
TpuExecutor_SyncAndForgetFailedStreams(executor_);
}
bool TpuExecutor::SynchronizeAllActivity() {
return TpuExecutor_SynchronizeAllActivity(executor_);
}
Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
StatusHelper status;
TpuExecutor_BlockHostUntilDone(
executor_, stream_map().at(stream->implementation()), status.c_status);
return status.status();
}
Status TpuExecutor::BlockUntilDoneOrFailed() {
StatusHelper status;
TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
return status.status();
}
Status TpuExecutor::GetStatus(Stream* stream) {
StatusHelper status;
TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
status.c_status);
return status.status();
}
bool TpuExecutor::AllocateStream(Stream* stream) {
return TpuExecutor_AllocateStream(executor_,
stream_map().at(stream->implementation()));
}
void TpuExecutor::DeallocateStream(Stream* stream) {
TpuExecutor_DeallocateStream(executor_,
stream_map().at(stream->implementation()));
stream_map().erase(stream->implementation());
}
bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
return TpuExecutor_CreateStreamDependency(
executor_, stream_map().at(dependent->implementation()),
stream_map().at(other->implementation()));
}
Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }
Status TpuExecutor::DeallocateEvent(Event* event) { return Status::OK(); }
// AllocateTimer/DeallocateTimer have no specialization.
bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
void TpuExecutor::DeallocateTimer(Timer* timer) {}
bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
return TpuExecutor_StartTimer(executor_,
stream_map().at(stream->implementation()),
timer_map_.at(timer->implementation()));
}
bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
return TpuExecutor_StopTimer(executor_,
stream_map().at(stream->implementation()),
timer_map_.at(timer->implementation()));
}
stream_executor::Event::Status TpuExecutor::PollForEventStatus(
stream_executor::Event* event) {
return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
executor_, event_map_.at(event->implementation())));
}
Status TpuExecutor::RecordEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
event_map_.at(event->implementation()),
status.c_status);
return status.status();
}
Status TpuExecutor::WaitForEvent(Stream* stream,
::stream_executor::Event* event) {
StatusHelper status;
TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
event_map_.at(event->implementation()),
status.c_status);
return status.status();
}
// Implementations for Timer, Stream, Event
// We need to map these implementations to internal equivalents -- thus we
// allocate the internal Timer, Stream and Event operations here, and map
// the implementations to the internal values. The "wrapper" interfaces are
// responsible for deallocating the internal value when they are destroyed.
// Called by Timer::Timer
std::unique_ptr<::stream_executor::internal::TimerInterface>
TpuExecutor::GetTimerImplementation() {
SE_Timer* tpu_timer = TpuTimer_New(executor_);
auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
timer_map_[ptr.get()] = tpu_timer;
return ptr;
}
// Called by Stream::Stream
std::unique_ptr<::stream_executor::internal::StreamInterface>
TpuExecutor::GetStreamImplementation() {
SE_Stream* tpu_stream = TpuStream_New(executor_);
auto ptr = absl::make_unique<TpuStream>(tpu_stream);
stream_map()[ptr.get()] = tpu_stream;
return ptr;
}
// Called by Event::Event
std::unique_ptr<::stream_executor::internal::EventInterface>
TpuExecutor::CreateEventImplementation() {
SE_Event* tpu_event = TpuEvent_New(executor_);
auto ptr = absl::make_unique<TpuEvent>(tpu_event);
event_map_[ptr.get()] = tpu_event;
return ptr;
}
DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
SE_DeviceMemoryBase se_base =
TpuExecutor_Allocate(executor_, size, memory_space);
return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
}
void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
TpuExecutor_Deallocate(executor_, &se_base);
}
void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
TpuExecutor_Deallocate(executor_, &se_base);
}
bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
int64_t _free;
int64_t _total;
if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
*free = _free;
*total = _total;
return true;
}
return false;
}
absl::optional<stream_executor::AllocatorStats>
TpuExecutor::GetAllocatorStats() {
SE_AllocatorStats c_stats;
if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
::stream_executor::AllocatorStats stats;
stats.num_allocs = c_stats.num_allocs;
stats.bytes_in_use = c_stats.bytes_in_use;
stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
stats.largest_alloc_size = c_stats.largest_alloc_size;
if (c_stats.has_bytes_limit) {
stats.bytes_limit = c_stats.bytes_limit;
}
stats.bytes_reserved = c_stats.bytes_reserved;
stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
if (c_stats.has_bytes_reservable_limit) {
stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
}
stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
return stats;
}
return {};
}
Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
StatusHelper status;
TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
status.c_status);
return status.status();
}
Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
StatusHelper status;
TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
status.c_status);
return status.status();
}
void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
absl::Span<uint8> bytes, StatusCallback done) {
StatusHelper status;
TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
bytes.size(), status.c_status);
done(status.status());
}
Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
absl::Span<const uint8> bytes) {
StatusHelper status;
TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
bytes.size(), status.c_status);
return status.status();
}
bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) {
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
return TpuExecutor_MemcpyToHost(executor_,
stream_map().at(stream->implementation()),
host_dst, &se_base, size);
}
bool TpuExecutor::Memcpy(Stream* stream,
::stream_executor::DeviceMemoryBase* device_dst,
const void* host_src, uint64 size) {
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
return TpuExecutor_MemcpyFromHost(executor_,
stream_map().at(stream->implementation()),
&se_base, host_src, size);
}
Status TpuExecutor::SynchronousMemcpy(
::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
uint64 size) {
StatusHelper status;
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
status.c_status);
return status.status();
}
Status TpuExecutor::SynchronousMemcpy(
void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) {
StatusHelper status;
SE_DeviceMemoryBase se_base =
TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
status.c_status);
return status.status();
}
Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
::stream_executor::DeviceMemoryBase* device_dst,
const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
return ::stream_executor::port::UnimplementedError(
"This operation not supported on TPU");
}
bool TpuExecutor::MemcpyDeviceToDevice(
Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
LOG(FATAL) << __func__ << " not supported on TpuExecutor";
}
struct HostCallbackContext {
std::function<Status()> callback;
};
SE_Status* HostCallbackTrampoline(void* ctx) {
HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
Status status = host_ctx->callback();
SE_Status* c_status =
TpuStatus_Create(status.code(), status.error_message().c_str());
delete host_ctx;
return c_status;
}
bool TpuExecutor::HostCallback(Stream* stream,
std::function<Status()> callback) {
HostCallbackContext* ctx = new HostCallbackContext{callback};
return TpuExecutor_HostCallback(executor_,
stream_map().at(stream->implementation()),
&HostCallbackTrampoline, ctx);
}
TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
TpuExecutor::CreateDeviceDescription() const {
StatusHelper status;
SE_DeviceDescription* description = TpuDeviceDescription_New();
auto cleanup = tensorflow::gtl::MakeCleanup(
[description]() { TpuDeviceDescription_Free(description); });
TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
if (status.status().ok()) {
stream_executor::internal::DeviceDescriptionBuilder builder;
CHECK_NE(description->device_vendor, nullptr);
builder.set_device_vendor(description->device_vendor);
builder.set_name(description->name);
builder.set_clock_rate_ghz(description->clock_rate_ghz);
builder.set_core_count(description->core_count);
builder.set_ecc_enabled(description->ecc_enabled);
builder.set_device_memory_size(description->device_memory_size);
builder.set_platform_version(description->platform_version);
return builder.Build();
}
return status.status();
}
} // namespace tensorflow

View File

@ -0,0 +1,241 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"
#include "tensorflow/stream_executor/timer.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
public:
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
using StatusCallback = std::function<void(const Status&)>;
using Stream = ::stream_executor::Stream;
using Event = ::stream_executor::Event;
using Timer = ::stream_executor::Timer;
using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
using StreamInterface = ::stream_executor::internal::StreamInterface;
using StreamExecutorInterface =
::stream_executor::internal::StreamExecutorInterface;
using EventMap =
absl::flat_hash_map<stream_executor::internal::EventInterface*,
SE_Event*>;
using TimerMap =
absl::flat_hash_map<stream_executor::internal::TimerInterface*,
SE_Timer*>;
explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
SE_StreamExecutor* executor)
: platform_(platform), executor_(executor) {}
~TpuExecutor() override;
Status Init(int device_ordinal,
::stream_executor::DeviceOptions device_options) override;
DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override;
StatusOr<DeviceMemoryBase> AllocateDeviceMemoryBase(uint64 size,
int64 memory_space);
Status AllocateEvent(Event* event) override;
bool AllocateStream(Stream* stream) override;
bool AllocateTimer(Timer* timer) override;
Status BlockHostUntilDone(::stream_executor::Stream* stream) override;
Status BlockUntilDoneOrFailed();
StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
CreateDeviceDescription() const override;
bool CreateStreamDependency(Stream* dependent, Stream* other) override;
void DeallocateStream(Stream* stream) override;
void Deallocate(const DeviceMemoryBase& memory);
void Deallocate(DeviceMemoryBase* memory) override;
Status DeallocateEvent(Event* event) override;
void DeallocateTimer(Timer* timer) override;
bool DeviceMemoryUsage(int64* free, int64* total) const override;
void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes,
StatusCallback done);
Status EnqueueInfeed(int32 infeed_queue_index,
absl::Span<const uint8> bytes);
absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
Status GetStatus(Stream* stream) override;
std::unique_ptr<::stream_executor::internal::StreamInterface>
GetStreamImplementation() override;
std::unique_ptr<::stream_executor::internal::TimerInterface>
GetTimerImplementation() override;
std::unique_ptr<::stream_executor::internal::EventInterface>
CreateEventImplementation() override;
bool HostCallback(Stream* stream, std::function<Status()> callback) override;
bool Memcpy(Stream* stream, void* host_dst,
const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) override;
bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
const void* host_src, uint64 size) override;
bool MemcpyDeviceToDevice(Stream* stream,
::stream_executor::DeviceMemoryBase* gpu_dst,
const ::stream_executor::DeviceMemoryBase& host_src,
uint64 size) override;
void SyncAndForgetFailedStreams();
bool SynchronizeAllActivity() override;
Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
const void* host_src, uint64 size) override;
Status SynchronousMemcpy(
void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) override;
Status SynchronousMemcpyDeviceToDevice(
::stream_executor::DeviceMemoryBase* device_dst,
const ::stream_executor::DeviceMemoryBase& device_src,
uint64 size) override;
int PlatformDeviceCount() override;
Event::Status PollForEventStatus(Event* event) override;
Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;
bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;
Status WaitForInfeedReady(int32 infeed_queue_index);
Status WaitForOutfeedReady(int32 outfeed_queue_index);
const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
return *platform_;
}
::tensorflow::tpu::TpuPlatformInterface& platform() override {
return *platform_;
}
// TODO(henrytan): convert this to override once the base interface is changed
// to TpuExecutorInterface.
StatusOr<std::unique_ptr<
tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
int64 size) override {
LOG(FATAL) << "Unimplemented.";
}
// -- Unimplemented (stubbed out) methods.
std::unique_ptr<stream_executor::internal::KernelInterface>
CreateKernelImplementation() override {
LOG(FATAL) << "Not yet implemented";
}
stream_executor::SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
LOG(FATAL) << "not yet implemented";
}
void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
Status MemZero(Stream* stream, DeviceMemoryBase* location,
uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
LOG(FATAL) << "not yet implemented";
}
bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
LOG(FATAL) << "not yet implemented";
}
Status SetDeviceSharedMemoryConfig(
stream_executor::SharedMemoryConfig config) override {
LOG(FATAL) << "not yet implemented";
}
void* HostMemoryAllocate(uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
void HostMemoryDeallocate(void* mem) override {
LOG(FATAL) << "not yet implemented";
}
bool HostMemoryRegister(void* mem, uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
bool HostMemoryUnregister(void* mem) override {
LOG(FATAL) << "not yet implemented";
}
Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
Status SynchronousMemSet(DeviceMemoryBase* location, int value,
uint64 size) override {
LOG(FATAL) << "not yet implemented";
}
private:
EventMap event_map_;
TimerMap timer_map_;
TpuPlatform::StreamMap& stream_map() {
return *(static_cast<TpuPlatform*>(platform_)->stream_map());
}
::tensorflow::tpu::TpuPlatformInterface* platform_;
SE_StreamExecutor* executor_;
};
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_

View File

@ -0,0 +1,293 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor;
typedef struct SE_Stream SE_Stream;
typedef struct SE_Event SE_Event;
typedef struct SE_Timer SE_Timer;
typedef struct SE_Status SE_Status;
typedef struct SE_PlatformId {
void* id; // aka stream_executor::Platform::Id
} SE_PlatformId;
typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
typedef struct SE_DeviceOptions SE_DeviceOptions;
typedef SE_Status* (*SE_StatusCallbackFn)(void*);
typedef struct SE_DeviceMemoryBase {
void* opaque;
uint64_t size;
uint64_t payload;
} SE_DeviceMemoryBase;
typedef struct SE_AllocatorStats {
int64_t num_allocs;
int64_t bytes_in_use;
int64_t peak_bytes_in_use;
int64_t largest_alloc_size;
bool has_bytes_limit;
int64_t bytes_limit;
int64_t bytes_reserved;
int64_t peak_bytes_reserved;
bool has_bytes_reservable_limit;
int64_t bytes_reservable_limit;
int64_t largest_free_block_bytes;
} SE_AllocatorStats;
typedef struct SE_DeviceDescription {
char* device_vendor;
char* platform_version;
char* driver_version;
char* runtime_version;
char* pci_bus_id;
char* name;
int64_t thread_dim_limit_x;
int64_t thread_dim_limit_y;
int64_t thread_dim_limit_z;
int64_t block_dim_limit_x;
int64_t block_dim_limit_y;
int64_t block_dim_limit_z;
int64_t threads_per_core_limit;
int64_t threads_per_block_limit;
int64_t threads_per_warp;
int64_t registers_per_core_limit;
int64_t registers_per_block_limit;
int64_t device_address_bits;
int64_t device_memory_size;
int64_t memory_bandwidth;
int64_t shared_memory_per_core;
int64_t shared_memory_per_block;
float clock_rate_ghz;
int cuda_compute_capability_major;
int cuda_compute_capability_minor;
int rocm_amdgpu_isa_version;
int numa_node;
int core_count;
bool ecc_enabled;
} SE_DeviceDescription;
typedef struct XLA_TransferManager XLA_TransferManager;
typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;
// Represents an XLA shape tree.
// Shapes are flattened in default traversal order.
typedef struct XLA_Shape {
char* bytes;
size_t size;
} XLA_Shape;
// Represents a leaf node for an XLA shaped buffer.
typedef struct XLA_ShapedBuffer {
XLA_Shape on_host_shape;
XLA_Shape on_device_shape;
int device_ordinal;
SE_DeviceMemoryBase* bases;
size_t count;
} XLA_ShapedBuffer;
// Represents a leaf XLA literal.
typedef struct XLA_Literal {
char** buffers;
size_t* sizes;
size_t count;
XLA_Shape shape;
} XLA_Literal;
typedef void (*XLA_CallbackFn)(void*);
typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);
extern "C" {
SE_Platform* TpuPlatform_New();
void TpuPlatform_Free(SE_Platform* platform);
void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
const char** options_key,
const char** options_value, SE_Status* status);
bool TpuPlatform_Initialized(SE_Platform* platform);
SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
SE_StreamExecutorConfig* config,
SE_Status* status);
SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);
void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
SE_DeviceOptions* device_options, SE_Status* status);
void TpuExecutor_Free(SE_StreamExecutor* executor);
int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor);
SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor,
uint64_t size, int64_t memory_space);
void TpuExecutor_Deallocate(SE_StreamExecutor* executor,
SE_DeviceMemoryBase* memory);
bool TpuExecutor_GetAllocatorStats(SE_StreamExecutor* executor,
SE_AllocatorStats* stats);
bool TpuExecutor_DeviceMemoryUsage(SE_StreamExecutor* executor, int64_t* free,
int64_t* total);
bool TpuExecutor_AllocateStream(SE_StreamExecutor* executor, SE_Stream* stream);
void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor,
SE_Stream* stream);
bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
SE_Stream* dependent, SE_Stream* other);
void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
SE_Status* status);
void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
SE_Status* status);
void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
SE_Status* status);
int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor,
SE_Event* event);
void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream,
SE_Event* event, SE_Status* status);
void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream,
SE_Event* event, SE_Status* status);
bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
bool TpuExecutor_StartTimer(SE_StreamExecutor* executor, SE_Stream* stream,
SE_Timer* timer);
bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream,
SE_Timer* timer);
void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor,
void* host_dst,
const SE_DeviceMemoryBase* device_src,
uint64_t size, SE_Status* status);
void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor,
SE_DeviceMemoryBase* device_dst,
const void* host_src, uint64_t size,
SE_Status* status);
bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream,
void* host_dst,
const SE_DeviceMemoryBase* device_src,
uint64_t size);
bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream,
SE_DeviceMemoryBase* device_dst,
const void* host_src, uint64_t size);
void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor,
int32_t infeed_queue_index, const uint8_t* data,
int64_t size, SE_Status* status);
void TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor,
int32_t outfeed_queue_index, uint8_t* data,
int64_t size, SE_Status* status);
void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor,
int32_t infeed_queue_index,
SE_Status* status);
void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor,
int32_t outfeed_queue_index,
SE_Status* status);
void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor,
SE_Stream* stream, SE_Status* status);
void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
SE_Status* status);
void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);
SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
bool TpuStream_Status(SE_Stream*);
SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
void TpuEvent_Free(SE_Event*);
SE_Timer* TpuTimer_New(SE_StreamExecutor* parent);
void TpuTimer_Free(SE_Timer*);
int64_t TpuTimer_Nanoseconds(SE_Timer*);
int64_t TpuTimer_Microseconds(SE_Timer*);
SE_Status* TpuStatus_New();
SE_Status* TpuStatus_Create(int32_t code, const char* msg);
void TpuStatus_Free(SE_Status* status);
const char* TpuStatus_Message(SE_Status* status);
int TpuStatus_Code(SE_Status* status);
bool TpuStatus_Ok(SE_Status* status);
SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default();
void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal);
void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*);
SE_DeviceDescription* TpuDeviceDescription_New();
void TpuDeviceDescription_Free(SE_DeviceDescription* description);
void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor,
SE_DeviceDescription* description,
SE_Status* status);
SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags);
void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options);
bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream,
SE_StatusCallbackFn callback_fn, void* ctx);
XLA_TransferManager* TpuTransferManager_New();
void TpuTransferManager_Free(XLA_TransferManager* manager);
SE_PlatformId TpuTransferManager_PlatformId(XLA_TransferManager* manager);
void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager,
XLA_Shape* host_shape,
XLA_Shape* device_shape);
void TpuTransferManager_TransferLiteralToDeviceAsync(
XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal,
XLA_ShapedBuffer* device_buffer, SE_Status* status);
void TpuTransferManager_TransferLiteralFromDevice(
XLA_TransferManager* manager, SE_Stream* stream,
XLA_ShapedBuffer* device_buffer, XLA_Literal* literal,
XLA_StatusCallbackFn callback, void* ctx);
int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager,
XLA_Shape* shape);
void TpuTransferManager_WriteSingleTupleIndexTable(
XLA_TransferManager* manager, SE_Stream* stream,
SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
SE_DeviceMemoryBase* region, SE_Status* status);
XLA_ComputationPlacer* TpuComputationPlacer_New();
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
}  // extern "C"
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/timer.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tpu {
class TpuCore;
} // namespace tpu
namespace tensorflow {
namespace tpu {
class TpuExecutorInterface
: public ::stream_executor::internal::StreamExecutorInterface {
public:
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
class TemporaryDeviceMemory {
public:
virtual ~TemporaryDeviceMemory() {}
virtual stream_executor::DeviceMemoryBase AsDeviceMemoryBase() const = 0;
};
virtual StatusOr<std::unique_ptr<TemporaryDeviceMemory>>
CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
int64 size) {
LOG(FATAL) << "Unimplemented.";
}
virtual const TpuPlatformInterface& platform() const {
LOG(FATAL) << "Unimplemented.";
}
virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
namespace tensorflow {
namespace tpu {
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
int device_ordinal) {
StatusHelper status;
XLA_TpuNodeContext* node_context =
TpuNodeContext_Create(device_ordinal, status.c_status);
if (!status.status().ok()) {
TpuNodeContext_Free(node_context);
return status.status();
}
return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
}
TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
/* static */
Status TpuNodeContext::StopChipHeartbeats() {
StatusHelper status;
TpuNodeContext_StopChipHeartbeats(status.c_status);
return status.status();
}
/* static */
Status TpuNodeContext::CloseTpuHost() {
StatusHelper status;
TpuNodeContext_CloseTpuHost(status.c_status);
return status.status();
}
/* static */
tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
return TpuPlatformInterface::GetRegisteredPlatform();
}
/* static */
stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
new stream_executor::StreamExecutorMemoryAllocator(
platform(),
xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
return memory_allocator;
}
/* static */
xla::Backend* TpuNodeContext::backend() {
static xla::Backend* backend =
xla::Backend::CreateBackend(
xla::BackendOptions().set_platform(platform()))
.ValueOrDie()
.release();
return backend;
}
/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
int device_ordinal) {
return backend()->BorrowStream(device_ordinal);
}
/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
stream_executor::StreamExecutor* executor) {
return backend()->BorrowStream(executor);
}
/* static */
xla::TransferManager* TpuNodeContext::transfer_manager() {
return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
namespace tpu {
class TpuNodeContext final {
public:
using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
int device_ordinal);
explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
: device_ordinal_(device_ordinal), node_context_(node_context) {
CHECK_NE(node_context, nullptr);
}
~TpuNodeContext();
TpuNodeContext(const TpuNodeContext&) = delete;
TpuNodeContext& operator=(const TpuNodeContext&) = delete;
static Status StopChipHeartbeats();
static Status CloseTpuHost();
static tensorflow::tpu::TpuPlatformInterface* platform();
static stream_executor::DeviceMemoryAllocator* memory_allocator();
static xla::TransferManager* transfer_manager();
static xla::Backend* backend();
static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);
static StatusOr<xla::StreamPool::Ptr> BorrowStream(
stream_executor::StreamExecutor* executor);
stream_executor::StreamExecutor* stream_executor() {
LOG(FATAL) << "Not implemented yet.";
}
std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }
int index_on_host() { LOG(FATAL) << "Not implemented yet."; }
int device_ordinal() const { return device_ordinal_; }
private:
const int device_ordinal_;
XLA_TpuNodeContext* const node_context_;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
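
Editor's note: a minimal usage sketch for the TpuNodeContext API declared above, assuming the TPU platform library is linked and registered. The function name RunOnDeviceZero and the error-handling scaffolding are illustrative only; the TpuNodeContext calls themselves come from this header.

// Illustrative sketch (not part of the commit).
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

tensorflow::Status RunOnDeviceZero() {
  using tensorflow::tpu::TpuNodeContext;
  // Create the per-device context through the C API wrapper.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Initialize(/*device_ordinal=*/0));
  // Borrow a stream from the process-wide XLA backend owned by the class.
  TF_ASSIGN_OR_RETURN(
      xla::StreamPool::Ptr stream,
      TpuNodeContext::BorrowStream(node_context->device_ordinal()));
  // The transfer manager and allocator are likewise shared, static helpers.
  xla::TransferManager* transfer_manager = TpuNodeContext::transfer_manager();
  (void)transfer_manager;
  (void)stream;
  return tensorflow::Status::OK();
}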

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status);
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_

View File

@ -0,0 +1,125 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
namespace tensorflow {
PLATFORM_DEFINE_ID(TpuPlatform::kId);
TpuPlatform* tpu_registered_platform = nullptr;
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }
TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
return tpu_registered_platform;
}
Status TpuPlatform::Initialize(
const std::map<std::string, std::string>& platform_options) {
StatusHelper status;
size_t options_size = platform_options.size();
const char** options_key =
static_cast<const char**>(malloc(sizeof(const char*) * options_size));
const char** options_value =
static_cast<const char**>(malloc(sizeof(const char*) * options_size));
size_t i = 0;
for (const auto& option : platform_options) {
options_key[i] = option.first.c_str();
options_value[i] = option.second.c_str();
i++;
}
TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
status.c_status);
free(options_key);
free(options_value);
return status.status();
}
TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }
int TpuPlatform::VisibleDeviceCount() const {
return TpuPlatform_VisibleDeviceCount(platform_);
}
StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
const ::stream_executor::StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
TpuPlatform::GetUncachedExecutor(
const ::stream_executor::StreamExecutorConfig& config) {
SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();
TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);
StatusHelper status;
SE_StreamExecutor* executor =
TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
TpuStreamExecutorConfig_Free(c_config);
if (!status.ok()) {
return status.status();
}
return std::make_unique<stream_executor::StreamExecutor>(
this, absl::make_unique<tensorflow::TpuExecutor>(this, executor),
config.ordinal);
}
::stream_executor::Platform::Id TpuPlatform::id() const {
return TpuPlatform::kId;
}
const std::string& TpuPlatform::Name() const {
static std::string* name = new std::string(kName);
return *name;
}
int64 TpuPlatform::TpuMemoryLimit() {
return TpuPlatform_TpuMemoryLimit(platform_);
}
} // namespace tensorflow
void RegisterTpuPlatform() {
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
std::unique_ptr<stream_executor::Platform> platform(
tensorflow::tpu_registered_platform);
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(platform)));
}
REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());
// Note that module initialization sequencing is not supported in the
// open-source project, so this will be a no-op there.
REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
tpu_platform);
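
Editor's note: since the sequencing macros above are no-ops in the open-source build (per the comment), a host binary that needs deterministic ordering could register the platform explicitly. This is a hedged sketch, not something the commit itself does; the forward declaration of RegisterTpuPlatform is added purely for illustration.

// Illustrative sketch (not part of the commit).
void RegisterTpuPlatform();  // defined in tpu_platform.cc above

int main(int argc, char** argv) {
  // Ensure the "TPU" platform exists before any MultiPlatformManager lookup.
  RegisterTpuPlatform();
  // ... the rest of the program can now resolve the platform by name or id.
  return 0;
}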

View File

@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
public:
using StreamMap =
absl::flat_hash_map<stream_executor::internal::StreamInterface*,
SE_Stream*>;
static const ::stream_executor::Platform::Id kId;
static constexpr char kName[] = "TPU";
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
TpuPlatform();
~TpuPlatform() override;
static TpuPlatform* GetRegisteredPlatform();
Id id() const override;
const std::string& Name() const override;
int VisibleDeviceCount() const override;
int64 TpuMemoryLimit() override;
bool Initialized() const override {
return TpuPlatform_Initialized(platform_);
}
Status Initialize(
const std::map<std::string, std::string>& platform_options) override;
Status Reset() override { return Reset(false); }
Status Reset(bool only_tear_down) override {
LOG(FATAL) << "Not yet implemented";
}
StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
DescriptionForDevice(int ordinal) const override {
LOG(FATAL) << "Not yet implemented";
}
StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice(
int ordinal) override {
stream_executor::StreamExecutorConfig config;
config.ordinal = ordinal;
return GetExecutor(config);
}
StatusOr<::stream_executor::StreamExecutor*>
ExecutorForDeviceWithPluginConfig(
int ordinal,
const ::stream_executor::PluginConfig& plugin_config) override {
stream_executor::StreamExecutorConfig config;
config.ordinal = ordinal;
config.plugin_config = plugin_config;
return GetExecutor(config);
}
StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
const ::stream_executor::StreamExecutorConfig& config) override;
StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
GetUncachedExecutor(
const ::stream_executor::StreamExecutorConfig& config) override;
void RegisterTraceListener(
std::unique_ptr<stream_executor::TraceListener> listener) override {
LOG(FATAL) << "Not yet implemented";
}
void UnregisterTraceListener(
stream_executor::TraceListener* listener) override {
LOG(FATAL) << "Not yet implemented";
}
StreamMap* stream_map() { return &stream_map_; }
private:
SE_Platform* platform_;
stream_executor::ExecutorCache executor_cache_;
StreamMap stream_map_;
};
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_

View File

@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
namespace tensorflow {
namespace tpu {
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
// Prefer TpuPlatform if it's registered.
auto status_or_tpu_platform =
stream_executor::MultiPlatformManager::PlatformWithName("TPU");
if (status_or_tpu_platform.ok()) {
return static_cast<TpuPlatformInterface*>(
status_or_tpu_platform.ValueOrDie());
}
if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
LOG(WARNING) << "Error when getting the TPU platform: "
<< status_or_tpu_platform.status();
return nullptr;
}
// Use any other registered TPU platform.
auto status_or_other_tpu_platforms =
stream_executor::MultiPlatformManager::PlatformsWithFilter(
[](const stream_executor::Platform* platform) {
return dynamic_cast<const TpuPlatformInterface*>(platform) !=
nullptr;
});
if (!status_or_other_tpu_platforms.ok()) {
LOG(WARNING) << "Error when getting other TPU platforms: "
             << status_or_other_tpu_platforms.status();
return nullptr;
}
auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
if (!other_tpu_platforms.empty()) {
LOG(WARNING) << other_tpu_platforms.size()
<< " TPU platforms registered, selecting "
<< other_tpu_platforms[0]->Name();
return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
}
LOG(WARNING) << "No TPU platform registered";
return nullptr;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"
namespace tensorflow {
namespace tpu {
class TpuPlatformInterface : public stream_executor::Platform {
public:
using Status = stream_executor::port::Status;
// Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
// registered, finds the most suitable one. Returns nullptr if no TPU platform
// is registered or an error occurred.
static TpuPlatformInterface* GetRegisteredPlatform();
virtual Status Reset() { return Reset(false); }
virtual Status Reset(bool only_tear_down) = 0;
virtual int64 TpuMemoryLimit() = 0;
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
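
Editor's note: a short sketch of how a TPU op kernel might consult GetRegisteredPlatform(). The helper name CheckTpuAvailable and the OP_REQUIRES plumbing are illustrative; only the TpuPlatformInterface calls come from this header.

// Illustrative sketch (not part of the commit).
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

void CheckTpuAvailable(tensorflow::OpKernelContext* ctx) {
  tensorflow::tpu::TpuPlatformInterface* platform =
      tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform();
  // Null means no TPU platform was registered (or lookup failed); see the
  // resolution order implemented in tpu_platform_interface.cc above.
  OP_REQUIRES(ctx, platform != nullptr,
              tensorflow::errors::FailedPrecondition(
                  "No TPU platform is registered on this host."));
  VLOG(1) << "Using TPU platform " << platform->Name() << " with "
          << platform->VisibleDeviceCount() << " visible device(s).";
}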

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
class TpuStream : public stream_executor::internal::StreamInterface {
public:
explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
~TpuStream() override { TpuStream_Free(stream_); }
private:
SE_Stream* stream_;
};
class TpuEvent : public ::stream_executor::internal::EventInterface {
public:
explicit TpuEvent(SE_Event* event) : event_(event) {}
~TpuEvent() override { TpuEvent_Free(event_); }
private:
SE_Event* event_;
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace tensorflow {
namespace tpu {
class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_

View File

@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
namespace tensorflow {
class TpuTimer : public ::stream_executor::internal::TimerInterface {
public:
explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
~TpuTimer() override { TpuTimer_Free(timer_); }
uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }
private:
SE_Timer* timer_;
};
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
namespace tensorflow {
using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
TpuTransferManager::TpuTransferManager() {
manager_ = TpuTransferManager_New();
}
TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
return TpuPlatform::kId;
}
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
const xla::Shape& host_shape) const {
XLA_Shape c_host_shape;
XLA_Shape c_device_shape;
TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
&c_device_shape);
xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
TpuConversions::CShapeCleanup(&c_host_shape);
TpuConversions::CShapeCleanup(&c_device_shape);
return device_shape;
}
Status TpuTransferManager::TransferLiteralToDeviceAsync(
stream_executor::Stream* stream, const xla::LiteralSlice& literal,
const xla::ShapedBuffer& device_buffer,
const TransferMetadata* transfer_metadata) {
StatusHelper status;
XLA_Literal c_literal;
TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
XLA_ShapedBuffer c_device_buffer;
TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
&c_device_buffer);
TpuTransferManager_TransferLiteralToDeviceAsync(
manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()),
&c_literal, &c_device_buffer, status.c_status);
TpuConversions::CShapedBufferCleanup(&c_device_buffer);
TpuConversions::CLiteralCleanup(&c_literal);
return status.status();
}
struct TransferFromDeviceState {
std::atomic<int64_t> remaining_transfers;
StatusHelper status_helper;
std::function<void(Status)> done;
void TransferFinished(SE_Status* status) {
if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
status_helper.c_status = status;
} else {
TpuStatus_Free(status);
}
if (--remaining_transfers == 0) {
done(status_helper.status());
delete this;
}
}
};
void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}
void TpuTransferManager::TransferLiteralFromDevice(
stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
const TransferMetadata* transfer_metadata) {
TransferFromDeviceState* state = new TransferFromDeviceState;
state->remaining_transfers = 1;
state->done = done;
XLA_ShapedBuffer c_device_buffer;
TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
&c_device_buffer);
XLA_Literal c_literal;
TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
TpuTransferManager_TransferLiteralFromDevice(
manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()),
&c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
TpuConversions::CShapedBufferCleanup(&c_device_buffer);
TpuConversions::CLiteralCleanup(&c_literal);
}
int64 TpuTransferManager::GetByteSizeRequirement(
const xla::Shape& shape) const {
XLA_Shape c_shape;
TpuConversions::XlaShapeToCShape(shape, &c_shape);
int64 size_in_bytes =
TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);
TpuConversions::CShapeCleanup(&c_shape);
return size_in_bytes;
}
Status TpuTransferManager::WriteSingleTupleIndexTable(
stream_executor::Stream* stream,
absl::Span<const stream_executor::DeviceMemoryBase> elements,
const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
CHECK_GT(elements.size(), 0);
SE_DeviceMemoryBase* elements_bases =
new SE_DeviceMemoryBase[elements.size()];
for (int i = 0; i < elements.size(); i++) {
elements_bases[i] =
SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
elements[i].size(), elements[i].payload()};
}
XLA_Shape c_shape;
TpuConversions::XlaShapeToCShape(shape, &c_shape);
SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
region->payload()};
StatusHelper status;
TpuTransferManager_WriteSingleTupleIndexTable(
manager_,
TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()),
elements_bases, elements.size(), &c_shape, &region_base, status.c_status);
delete[] elements_bases;
TpuConversions::CShapeCleanup(&c_shape);
return status.status();
}
} // namespace tensorflow

View File

@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
namespace tensorflow {
class TpuTransferManager : public xla::TransferManager {
public:
TpuTransferManager();
~TpuTransferManager() override;
using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
stream_executor::Platform::Id PlatformId() const override;
xla::Shape HostShapeToDeviceShape(
const xla::Shape& host_shape) const override;
Status TransferLiteralToDeviceAsync(
stream_executor::Stream* stream, const xla::LiteralSlice& literal,
const xla::ShapedBuffer& device_buffer,
const TransferMetadata* transfer_metadata) override;
void TransferLiteralFromDevice(
stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
const TransferMetadata* transfer_metadata) override;
Status TransferLiteralToInfeed(stream_executor::StreamExecutor* executor,
const xla::LiteralSlice& literal) override {
LOG(FATAL) << "Not yet implemented";
}
Status TransferLiteralFromOutfeed(
stream_executor::StreamExecutor* executor,
const xla::Shape& literal_shape,
xla::MutableBorrowingLiteral literal) override {
LOG(FATAL) << "Not yet implemented";
}
Status ResetDevices(
absl::Span<stream_executor::StreamExecutor* const> executor) override {
LOG(FATAL) << "Not yet implemented";
}
int64 GetByteSizeRequirement(const xla::Shape& shape) const override;
Status WriteSingleTupleIndexTable(
stream_executor::Stream* stream,
absl::Span<const stream_executor::DeviceMemoryBase> elements,
const xla::Shape& shape,
stream_executor::DeviceMemoryBase* region) override;
private:
XLA_TransferManager* manager_;
};
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
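
Editor's note: the transfer manager's shape and size queries are the pieces callers typically combine when staging device buffers. A hedged sketch follows, reusing the TpuNodeContext helpers from earlier in this commit; AllocateForHostShape is an illustrative name only.

// Illustrative sketch (not part of the commit).
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

xla::StatusOr<stream_executor::OwningDeviceMemory> AllocateForHostShape(
    const xla::Shape& host_shape, int device_ordinal) {
  xla::TransferManager* tm =
      tensorflow::tpu::TpuNodeContext::transfer_manager();
  // The device may pick a different (e.g. padded or tiled) layout, so size
  // the allocation from the device shape, not the host shape.
  xla::Shape device_shape = tm->HostShapeToDeviceShape(host_shape);
  tensorflow::int64 bytes = tm->GetByteSizeRequirement(device_shape);
  return tensorflow::tpu::TpuNodeContext::memory_allocator()->Allocate(
      device_ordinal, bytes, /*retry_on_failure=*/true);
}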

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
namespace tensorflow {
static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
return std::make_unique<TpuTransferManager>();
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
CreateTpuTransferManager);
return true;
}
static bool module_initialized = InitModule();
} // namespace tensorflow