Open sourcing some TPU-related work
PiperOrigin-RevId: 315431095 Change-Id: I734632c0e5723dfca37acf53bbbd2b378b04c95d
parent d9e5e2f7b3 · commit 13c09da422

tensorflow/core/tpu/graph_rewrite/BUILD (new file, 55 lines)
@@ -0,0 +1,55 @@
# Contains graph rewrites for TPU runtimes and optimizations.

package(
    default_visibility = [
        "//tensorflow/core/tpu:__subpackages__",
        "//tensorflow/stream_executor/tpu:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "distributed_tpu_configuration_rewrite_registration",
    srcs = ["distributed_tpu_configuration_rewrite_registration.cc"],
    deps = [
        ":distributed_tpu_configuration_rewrite_pass",
        "//tensorflow/core:core_cpu",
    ],
    alwayslink = 1,
)

cc_library(
    name = "distributed_tpu_configuration_rewrite_pass",
    srcs = [
        "distributed_tpu_configuration_rewrite_pass.cc",
    ],
    hdrs = [
        "distributed_tpu_configuration_rewrite_pass.h",
    ],
    deps = [
        ":distributed_tpu_rewrite_helpers",
        "//tensorflow/cc:scope",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
        "//tensorflow/core/tpu:tpu_init_mode",
        "//tensorflow/core/tpu/kernels:tpu_compile_op_options",
    ],
)

cc_library(
    name = "distributed_tpu_rewrite_helpers",
    srcs = ["distributed_tpu_rewrite_helpers.cc"],
    hdrs = ["distributed_tpu_rewrite_helpers.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/tpu:tpu_defs",
    ],
)
tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc (new file, 402 lines)
@@ -0,0 +1,402 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Configuration for distributed TPU jobs.

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

#include <unordered_map>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/tpu_init_mode.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {
namespace {

constexpr char kIdentityOp[] = "Identity";
constexpr char kConfigureOp[] = "ConfigureDistributedTPU";
constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU";
constexpr char kWaitOp[] = "_WaitForDistributedTPU";
constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU";
constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray";
constexpr char kShutdownOp[] = "ShutdownDistributedTPU";
constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU";
constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem";
constexpr char kEmbeddingConfigurationAttr[] = "embedding_config";
constexpr int kDefaultStartupTimeout = 20;
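
// Summary of the rewrite implemented by the helpers below: each dummy
// ConfigureDistributedTPU node is replaced by a small subgraph consisting of
// a _ConfigureDistributedTPU node on the TPU_SYSTEM device, one
// _InitializeHostForDistributedTPU node per host, a _WaitForDistributedTPU
// barrier, one _SetGlobalTPUArray node per host, and a final Identity node
// that takes over the original node's name and output edges.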

Status AddConfigurationNode(const string& configuration_device_name,
                            int number_of_hosts, Graph* graph,
                            bool enable_whole_mesh_compilations,
                            Node** configuration_node) {
  NodeDef config_def;
  config_def.set_name(graph->NewName("configure_distributed_tpu"));
  config_def.set_op(kInternalConfigureOp);
  config_def.set_device(configuration_device_name);
  AddNodeAttr("N", number_of_hosts, &config_def);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &config_def);
  // TODO(shikharagarwal): Fill with appropriate original node debug info.

  Status status;
  *configuration_node = graph->AddNode(config_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*configuration_node)->set_assigned_device_name(configuration_device_name);
  return Status::OK();
}

Status AddHostConfigNode(const string& host_device_name,
                         Node* configuration_node, Graph* graph,
                         bool enable_whole_mesh_compilations,
                         Node** host_configuration_node) {
  NodeDef host_config_def;
  host_config_def.set_name(graph->NewName("configure_tpu_host"));
  host_config_def.set_op(kHostConfigureOp);
  host_config_def.set_device(host_device_name);
  AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations,
              &host_config_def);
  MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def);

  Status status;
  *host_configuration_node = graph->AddNode(host_config_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*host_configuration_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(configuration_node, 0, *host_configuration_node, 0);
  return Status::OK();
}

Status AddWaitNode(const string& configuration_device_name,
                   const std::vector<Node*>& host_configuration_nodes,
                   Graph* graph, Node** wait_node) {
  NodeDef wait_def;
  wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system"));
  wait_def.set_op(kWaitOp);
  wait_def.set_device(configuration_device_name);
  AddNodeAttr("N", static_cast<int32>(host_configuration_nodes.size()),
              &wait_def);
  AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def);
  if (!host_configuration_nodes.empty()) {
    MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()),
                   &wait_def);
  }

  Status status;
  *wait_node = graph->AddNode(wait_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*wait_node)->set_assigned_device_name(configuration_device_name);
  // Get the inputs from the host configuration nodes.
  for (int i = 0; i < host_configuration_nodes.size(); ++i) {
    graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i);
  }
  return Status::OK();
}

Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node,
                             Graph* graph, Node** global_tpu_array_node) {
  NodeDef global_tpu_array_def;
  global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array"));
  global_tpu_array_def.set_op(kGlobalTPUArrayOp);
  global_tpu_array_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def);

  Status status;
  *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*global_tpu_array_node)->set_assigned_device_name(host_device_name);
  graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0);
  return Status::OK();
}

Status AddSynchronizationNode(
    const NodeDef& sync_node_def, const string& device_name,
    const std::vector<Node*>& global_array_id_nodes, Node* wait_node,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph) {
  NodeDef sync_def;
  sync_def.set_name(sync_node_def.name());
  sync_def.set_op(kIdentityOp);
  sync_def.set_device(device_name);
  AddNodeAttr("T", DT_STRING, &sync_def);
  MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def);

  Status status;
  Node* sync_node = graph->AddNode(sync_def, &status);
  if (!status.ok()) {
    return status;
  }
  sync_node->set_assigned_device_name(device_name);
  // Add control edges from the global array id nodes.
  for (auto node : global_array_id_nodes) {
    graph->AddControlEdge(node, sync_node);
  }
  // Forward the data from the wait node.
  graph->AddEdge(wait_node, 0, sync_node, 0);
  // Replace the output edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input == Graph::kControlSlot) {
      graph->AddControlEdge(sync_node, dep.dst);
    } else {
      graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input);
    }
  }
  return Status::OK();
}

Status AddShutdownNode(
    const NodeDef& shutdown_node_def, const string& shutdown_device_name,
    const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
        output_dependencies,
    Graph* graph, Node** shutdown_node) {
  NodeDef shutdown_def;
  shutdown_def.set_name(shutdown_node_def.name());
  shutdown_def.set_op(kInternalShutdownOp);
  shutdown_def.set_device(shutdown_device_name);
  MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def);

  Status status;
  *shutdown_node = graph->AddNode(shutdown_def, &status);
  if (!status.ok()) {
    return status;
  }
  (*shutdown_node)->set_assigned_device_name(shutdown_device_name);
  // Replace the output control edges.
  for (const DistributedTPURewriteHelpers::OutputDependency& dep :
       output_dependencies) {
    if (dep.dst_input != Graph::kControlSlot) {
      return errors::Internal("Shutdown node had non-control edge output");
    }
    graph->AddControlEdge(*shutdown_node, dep.dst);
  }
  return Status::OK();
}

Status AddHostDisconnectNode(const string& host_device_name,
                             const std::vector<Node*>& input_dependencies,
                             Node* post_disconnect_node, int output_index,
                             Graph* graph) {
  NodeDef host_disconnect_def;
  host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host"));
  host_disconnect_def.set_op(kHostDisconnectOp);
  host_disconnect_def.set_device(host_device_name);
  MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()),
                 &host_disconnect_def);

  Status status;
  Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status);
  if (!status.ok()) {
    return status;
  }
  host_disconnect_node->set_assigned_device_name(host_device_name);
  // Replace the input control edges.
  for (Node* src_node : input_dependencies) {
    graph->AddControlEdge(src_node, host_disconnect_node);
  }
  if (output_index == -1) {
    graph->AddControlEdge(host_disconnect_node, post_disconnect_node);
  } else {
    graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index);
  }
  return Status::OK();
}

}  // namespace

Status DistributedTPUConfigurationRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kConfigureOp, graph, *options.device_set,
          [](const NodeDef& configuration_node_def,
             const string& configuration_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            const std::string& embedding_attr_string = GetNodeAttrString(
                AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr);

            if (!embedding_attr_string.empty()) {
              return errors::InvalidArgument("embedding_config must be empty.");
            }

            bool is_global_init = false;
            bool enable_whole_mesh_compilations = false;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "is_global_init", &is_global_init));
            TryGetNodeAttr(configuration_node_def,
                           "enable_whole_mesh_compilations",
                           &enable_whole_mesh_compilations);
            TF_RETURN_IF_ERROR(SetTPUInitMode(
                is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular));

            bool compilation_failure_closes_chips;
            TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def,
                                           "compilation_failure_closes_chips",
                                           &compilation_failure_closes_chips));
            internal::SetTpuCompilationFailureClosesChips(
                compilation_failure_closes_chips);

            // Add the global TPU system configuration node.
            Node* configuration_node;
            TF_RETURN_IF_ERROR(AddConfigurationNode(
                configuration_device_name, host_devices.size(), graph,
                enable_whole_mesh_compilations, &configuration_node));

            // Add the host disconnect nodes.
            for (int i = 0; i < host_devices.size(); ++i) {
              const auto host_device = host_devices[i];
              TF_RETURN_IF_ERROR(AddHostDisconnectNode(
                  host_device->name(), input_dependencies, configuration_node,
                  i, graph));
            }

            // Add the host configuration nodes.
            std::vector<Node*> host_configuration_nodes;
            for (const auto host_device : host_devices) {
              Node* host_configuration_node;
              TF_RETURN_IF_ERROR(AddHostConfigNode(
                  host_device->name(), configuration_node, graph,
                  enable_whole_mesh_compilations, &host_configuration_node));
              host_configuration_nodes.push_back(host_configuration_node);
            }

            // Add the node to wait for the system configuration to
            // stabilize. Use the name of the original dummy Op in case it was
            // the target of a Session::Run call.
            Node* wait_node;
            TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name,
                                           host_configuration_nodes, graph,
                                           &wait_node));

            // Add the nodes to set the global TPU ids at each host.
            std::vector<Node*> global_array_id_nodes;
            for (const auto host_device : host_devices) {
              Node* global_array_id_node;
              TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(),
                                                       wait_node, graph,
                                                       &global_array_id_node));
              global_array_id_nodes.push_back(global_array_id_node);
            }

            if (host_devices.empty()) {
              return errors::InvalidArgument("TPU job contains no CPU devices");
            }

            TF_RETURN_IF_ERROR(AddSynchronizationNode(
                configuration_node_def, host_devices.front()->name(),
                global_array_id_nodes, wait_node, output_dependencies, graph));

            return Status::OK();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_configuration_after", *graph,
                    options.flib_def);
  }

  VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished";
  return Status::OK();
}

Status DistributedTPUShutdownRewritePass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "DistributedTPUShutdownRewritePass::Run";

  Graph* graph = options.graph->get();

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_before", *graph,
                    options.flib_def);
  }

  // This pass can only run in the session master, which should fill
  // in the device_set field of the options.
  TF_RET_CHECK(options.device_set != nullptr);

  TF_RETURN_IF_ERROR(
      DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
          kShutdownOp, graph, *options.device_set,
          [](const NodeDef& shutdown_node_def,
             const string& shutdown_device_name,
             const std::vector<Device*>& host_devices,
             const std::vector<Node*>& input_dependencies,
             const std::vector<DistributedTPURewriteHelpers::OutputDependency>&
                 output_dependencies,
             Graph* graph) -> Status {
            Node* shutdown_node;
            TF_RETURN_IF_ERROR(
                AddShutdownNode(shutdown_node_def, shutdown_device_name,
                                output_dependencies, graph, &shutdown_node));

            // Add the host disconnect nodes.
            for (const auto host_device : host_devices) {
              TF_RETURN_IF_ERROR(AddHostDisconnectNode(
                  host_device->name(), input_dependencies, shutdown_node, -1,
                  graph));
            }

            return Status::OK();
          }));

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def);
  }

  VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished";
  return Status::OK();
}

}  // namespace tensorflow
tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h (new file, 51 lines)
@@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites the ConfigureDistributedTPU Op into a graph that configures each
// host.
//
// See the comment at the top of
// third_party/tensorflow/core/ops/tpu_configuration_ops.cc for the
// sequence of Ops used to configure a distributed TPU system.

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM
// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU
// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU
// device of each host in the same job as the given TPU_SYSTEM device.
class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};

// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM
// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and
// _DisconnectHostFromDistributedTPUSystem on the CPU device of each
// host in the same job as the given TPU_SYSTEM device.
class DistributedTPUShutdownRewritePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_
tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc (new file, 29 lines)
@@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"

namespace tensorflow {
namespace {

// This pass removes the TPUEmbeddingConfiguration in ConfigureDistributedTPU.
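// Both rewrites are registered in the PRE_PLACEMENT grouping (phase 20), so
// the dummy Ops are replaced before the placer assigns nodes to devices.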
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
                      DistributedTPUConfigurationRewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
                      DistributedTPUShutdownRewritePass);

}  // namespace
}  // namespace tensorflow
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc (new file, 255 lines)
@@ -0,0 +1,255 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Helper functions for TPU rewrite passes.

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h"

#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

// LINT.IfChange
Status DistributedTPURewriteHelpers::GetSystemDevice(
    const string& system_spec_string, const DeviceSet& device_set,
    DeviceNameUtils::ParsedName* system_spec, Device** system_device) {
  if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) {
    system_spec->Clear();
  }

  // Callers may have relied on an Op only being registered on TPU_SYSTEM
  // devices to ensure the Op is placed there. Augment the device spec to make
  // the device type explicit.
  if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) {
    system_spec->type = DEVICE_TPU_SYSTEM;
    system_spec->has_type = true;
    system_spec->id = 0;
    system_spec->has_id = true;
  }

  std::vector<Device*> system_devices;
  device_set.FindMatchingDevices(*system_spec, &system_devices);
  if (system_devices.empty()) {
    if (system_spec_string.empty()) {
      return errors::InvalidArgument(
          "No TPU_SYSTEM device found. Please ensure that you're connected to "
          "a host with a TPU_SYSTEM device.");
    }
    return errors::InvalidArgument("No matching devices found for '",
                                   system_spec_string, "'");
  } else if (system_devices.size() > 1) {
    // Validate that all system devices are part of the same job.
    std::unordered_set<string> job_names;
    for (auto device : system_devices) {
      const auto& parsed_name = device->parsed_name();
      TF_RET_CHECK(parsed_name.has_job);
      job_names.insert(parsed_name.job);
    }
    if (job_names.size() > 1) {
      return errors::InvalidArgument(
          "System devices cannot be part "
          "of multiple different jobs. Found: ",
          str_util::Join(job_names, ","));
    }

    // Identify the lexicographically first device from the list of
    // valid TPU SYSTEM devices, so that every process in the same
    // 'cluster' definition uses the same system device.
    std::sort(system_devices.begin(), system_devices.end(),
              [](Device* i, Device* j) {
                auto i_name = i->parsed_name();
                auto j_name = j->parsed_name();
                if (i_name.replica != j_name.replica) {
                  return i_name.replica < j_name.replica;
                }
                return i_name.task < j_name.task;
              });
  }

  *system_device = system_devices[0];
  if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) {
    return errors::InvalidArgument("Unable to re-parse system device name ",
                                   (*system_device)->name(),
                                   " as a device spec.");
  }
  return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)

// LINT.IfChange
Status DistributedTPURewriteHelpers::GetHostSystemDevices(
    const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
    std::vector<Device*>* host_system_devices) {
  DeviceNameUtils::ParsedName host_spec;
  if (system_spec.has_job) {
    // The system Op has been explicitly assigned to a job, so we want
    // all the hosts in that job.
    CHECK(DeviceNameUtils::ParseFullName(
        strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM,
                        ":0"),
        &host_spec));
  } else {
    // The system Op has not been explicitly assigned to a
    // job, so take all hosts in the system. There will be a runtime
    // error if some of those hosts don't contain TPU devices.
    CHECK(DeviceNameUtils::ParseFullName(
        strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec));
  }
  device_set.FindMatchingDevices(host_spec, host_system_devices);

  TF_RET_CHECK(!host_system_devices->empty())
      << "No hosts found matching device spec "
      << DeviceNameUtils::ParsedNameToString(host_spec);

  // Check that all the devices belong to the same job and replica.
  TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job);
  const string& job_name = (*host_system_devices)[0]->parsed_name().job;
  int replica = (*host_system_devices)[0]->parsed_name().replica;
  for (const auto host_device : *host_system_devices) {
    const auto& parsed_name = host_device->parsed_name();
    TF_RET_CHECK(parsed_name.has_job);
    if (parsed_name.job != job_name) {
      return errors::InvalidArgument(
          "All TPU host devices must be in the same job");
    }
    TF_RET_CHECK(parsed_name.has_replica);
    if (parsed_name.replica != replica) {
      return errors::InvalidArgument(
          "All TPU host devices must be in the same replica");
    }
  }

  // All devices are in the same replica, so sort them by task.
  std::sort(host_system_devices->begin(), host_system_devices->end(),
            [](Device* i, Device* j) {
              auto i_name = i->parsed_name();
              auto j_name = j->parsed_name();
              return i_name.task < j_name.task;
            });
  return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)

// LINT.IfChange
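// For example, with two hosts of eight TPU devices each, this produces a
// 2 x 8 array in which (*tpu_devices)[1][3] is device TPU:3 on the
// second host (ordered by task, then by device id).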
Status DistributedTPURewriteHelpers::GetTPUDevices(
    const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set,
    int* num_tpus_per_host, std::vector<std::vector<Device*>>* tpu_devices) {
  // GetHostSystemDevices returns the CPU device on each host that is
  // going to be used for executing TPU code.
  std::vector<Device*> host_system_devices;
  TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices(
      system_spec, device_set, &host_system_devices));

  // Enumerate all the physical devices. Enumerate devices on task 0,
  // then task 1, etc.
  std::sort(host_system_devices.begin(), host_system_devices.end(),
            [](Device* i, Device* j) {
              return i->parsed_name().task < j->parsed_name().task;
            });

  *num_tpus_per_host = 0;
  tpu_devices->clear();
  tpu_devices->reserve(host_system_devices.size());
  for (const auto device : host_system_devices) {
    // Make a copy of the parsed name because we are going to change it.
    DeviceNameUtils::ParsedName device_spec = device->parsed_name();
    device_spec.has_type = true;
    device_spec.type = "TPU";
    // Enumerate all the available TPUs.
    device_spec.has_id = false;
    std::vector<Device*> host_tpu_devices;
    device_set.FindMatchingDevices(device_spec, &host_tpu_devices);
    // Sort the devices by device id.
    std::sort(host_tpu_devices.begin(), host_tpu_devices.end(),
              [](Device* i, Device* j) {
                return i->parsed_name().id < j->parsed_name().id;
              });
    if (tpu_devices->empty()) {
      // First iteration: set *num_tpus_per_host to the number of TPUs on the
      // first host.
      *num_tpus_per_host = host_tpu_devices.size();
    } else if (*num_tpus_per_host != host_tpu_devices.size()) {
      // Subsequent iterations: check that the number of TPUs matches the
      // number on the first host.
      return errors::InvalidArgument(
          "Mismatched number of TPU devices in cluster ", *num_tpus_per_host,
          " vs. ", host_tpu_devices.size());
    }
    tpu_devices->push_back(std::move(host_tpu_devices));
  }
  return Status::OK();
}
// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc)

Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
    const string& node_type, Graph* graph, const DeviceSet& device_set,
    const std::function<
        Status(const NodeDef& configuration_node_def,
               const string& configuration_device_name,
               const std::vector<Device*>& host_devices,
               const std::vector<Node*>& input_dependencies,
               const std::vector<OutputDependency>& output_dependencies,
               Graph* graph)>& action) {
  // Find all the matching nodes before mutating the graph.
  std::vector<Node*> nodes;
  for (Node* node : graph->nodes()) {
    if (node->type_string() == node_type) {
      nodes.push_back(node);
    }
  }

  for (Node* node : nodes) {
    string spec_string = node->requested_device();
    DeviceNameUtils::ParsedName spec;
    Device* device;
    TF_RETURN_IF_ERROR(
        GetSystemDevice(spec_string, device_set, &spec, &device));
    const string& device_name = device->name();

    std::vector<Device*> host_devices;
    TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices));

    std::vector<Node*> input_dependencies;
    for (const Edge* edge : node->in_edges()) {
      // Config Ops have no inputs, so all incoming edges must be control
      // edges.
      CHECK(edge->IsControlEdge());
      input_dependencies.push_back(edge->src());
    }
    std::vector<OutputDependency> output_dependencies;
    for (const Edge* edge : node->out_edges()) {
      OutputDependency dep;
      dep.src_output = edge->src_output();
      dep.dst = edge->dst();
      dep.dst_input = edge->dst_input();
      output_dependencies.push_back(dep);
    }
    NodeDef node_def = node->def();

    // Remove the node now so we can insert a new node with the same
    // name inside the action.
    graph->RemoveNode(node);

    TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices,
                              input_dependencies, output_dependencies, graph));
  }

  return Status::OK();
}

}  // namespace tensorflow
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h (new file, 98 lines)
@@ -0,0 +1,98 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Helper functions for TPU rewrite passes.

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

class DistributedTPURewriteHelpers {
 public:
  // Given a user-assigned device string, system_spec_string, parse it into
  // system_spec. Verify that the device type is either TPU_SYSTEM or
  // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the
  // type, verify that the spec matches a unique device in device_set, and
  // return that device in system_device. The normal use case is for
  // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the
  // job that contains the TPU hardware.
  // TODO(b/110910013): Possibly remove the tpu system device.
  static Status GetSystemDevice(const string& system_spec_string,
                                const DeviceSet& device_set,
                                DeviceNameUtils::ParsedName* system_spec,
                                Device** system_device);

  // Given a parsed system spec (e.g., the one returned above from
  // GetSystemDevice), return in host_system_devices the TPU_SYSTEM:0 device on
  // every host in the spec's job. If the spec does not include an explicit
  // job, "localhost" is used. Returns an error if system_spec matches devices
  // from multiple jobs or replicas.
  static Status GetHostSystemDevices(
      const DeviceNameUtils::ParsedName& system_spec,
      const DeviceSet& device_set, std::vector<Device*>* host_system_devices);

  // Given a parsed system spec (e.g., the one returned above from
  // GetSystemDevice), sets `*tpu_devices` to a per-host vector of the TPU
  // devices on every host in the spec's job. If the spec does not include an
  // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number
  // of TPU devices in each host, and verifies that each host in the job has
  // the same number of TPU devices.
  // Returns an error if system_spec matches devices from multiple jobs or
  // replicas.
  static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec,
                              const DeviceSet& device_set,
                              int* num_tpus_per_host,
                              std::vector<std::vector<Device*>>* tpu_devices);

  // Performs 'action' on every node in 'graph' of type 'node_type'. This
  // function is designed for use with configuration Ops that have no inputs
  // or outputs. The arguments passed to 'action' are:
  // 'configuration_node_def': the NodeDef of the node that matched
  // 'configuration_device_name': the name of the device that the
  // matching node is placed on
  // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are
  // in the same system as the node that matched.
  // 'input_dependencies': the set of nodes that have control edges to
  // the matching node.
  // 'output_dependencies': the set of (output port, destination node, input
  // port) triples for edges leaving the matching node. The input port is
  // Graph::kControlSlot for a control edge.
  // 'graph': the graph being mutated.
  struct OutputDependency {
    int src_output;
    Node* dst;
    int dst_input;
  };
  static Status ForConfigurationNodeMatchingType(
      const string& node_type, Graph* graph, const DeviceSet& device_set,
      const std::function<
          Status(const NodeDef& configuration_node_def,
                 const string& configuration_device_name,
                 const std::vector<Device*>& host_devices,
                 const std::vector<Node*>& input_dependencies,
                 const std::vector<OutputDependency>& output_dependencies,
                 Graph* graph)>& action);
};
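
// A minimal usage sketch, adapted from the call in
// distributed_tpu_configuration_rewrite_pass.cc in this change: a rewrite
// pass supplies the Op type to match and a callback that rebuilds the node.
//
//   TF_RETURN_IF_ERROR(
//       DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType(
//           "ConfigureDistributedTPU", graph, *options.device_set,
//           [](const NodeDef& configuration_node_def,
//              const string& configuration_device_name,
//              const std::vector<Device*>& host_devices,
//              const std::vector<Node*>& input_dependencies,
//              const std::vector<
//                  DistributedTPURewriteHelpers::OutputDependency>&
//                  output_dependencies,
//              Graph* graph) -> Status {
//             // Build replacement nodes here; the matched node has already
//             // been removed from the graph.
//             return Status::OK();
//           }));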

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_
tensorflow/core/tpu/kernels/BUILD (new file, 288 lines)
@@ -0,0 +1,288 @@
# TPU Kernel Implementations
load(
    "//tensorflow/core/platform:build_config.bzl",
    "tf_proto_library_cc",
)

package(
    default_visibility = [
        "//tensorflow/core/tpu:__subpackages__",
        "//tensorflow/stream_executor/tpu:__subpackages__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "tpu_compile_op_options",
    srcs = ["tpu_compile_op_options.cc"],
    hdrs = ["tpu_compile_op_options.h"],
)

cc_library(
    name = "tpu_configuration_ops",
    srcs = ["tpu_configuration_ops.cc"],
    hdrs = ["tpu_configuration_ops.h"],
    deps = [
        ":tpu_mesh_state_interface",
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:refcount",
        "//tensorflow/core/tpu:tpu_config_c_api",
        "//tensorflow/core/tpu:tpu_configuration",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu:tpu_library_loader",
        "//tensorflow/stream_executor/tpu:proto_helper",
    ],
    alwayslink = 1,
)

cc_library(
    name = "tpu_compile_c_api_hdrs",
    hdrs = ["tpu_compile_c_api.h"],
    deps = [
        ":tpu_mesh_state_c_api",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
    ],
)

tf_proto_library_cc(
    name = "tpu_executable_info_proto",
    srcs = ["tpu_executable_info.proto"],
    cc_api_version = 2,
    protodeps = [
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo_proto",
        "//tensorflow/core:protos_all",
    ],
)

tf_proto_library_cc(
    name = "tpu_compile_proto",
    srcs = ["tpu_compile.proto"],
    cc_api_version = 2,
    protodeps = [
        ":tpu_executable_info_proto",
        "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo_proto",
        "//tensorflow/core:protos_all",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto",
    ],
)

cc_library(
    name = "tpu_compilation_cache_key",
    srcs = [],
    hdrs = [
        "tpu_compilation_cache_key.h",
    ],
    deps = ["@com_google_absl//absl/types:optional"],
)

cc_library(
    name = "tpu_compile_op_support",
    srcs = ["tpu_compile_op_support.cc"],
    hdrs = ["tpu_compile_op_support.h"],
    deps = [
        ":tpu_compilation_cache_key",
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_proto_cc",
        ":tpu_executable_info_proto_cc",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:dump",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_module_group",
        "//tensorflow/core:framework",
        "//tensorflow/core/framework:protos_all_cc",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:status_helper",
        "@com_google_absl//absl/types:optional",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "tpu_compilation_cache_entry",
    hdrs = [
        "tpu_compilation_cache_entry.h",
    ],
    deps = [
        ":tpu_executable_info_proto_cc",
        ":tpu_program",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core/lib/core:refcount",
    ],
)

cc_library(
    name = "tpu_compilation_cache_lookup",
    srcs = ["tpu_compilation_cache_lookup.cc"],
    hdrs = [
        "tpu_compilation_cache_lookup.h",
    ],
    deps = [
        ":tpu_compilation_cache_entry",
        ":tpu_compilation_cache_external",
        ":tpu_compilation_cache_proto_cc",
        "//tensorflow/core/lib/core:refcount",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/profiler/lib:traceme",
    ],
)

cc_library(
    name = "tpu_mesh_state_c_api",
    hdrs = ["tpu_mesh_state_c_api.h"],
)

cc_library(
    name = "tpu_mesh_state_interface",
    srcs = [],
    hdrs = ["tpu_mesh_state_interface.h"],
    deps = [
        ":tpu_compile_c_api_hdrs",
        ":tpu_mesh_state_c_api",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/core:framework",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/tpu:tpu_config_c_api",
    ],
)

cc_library(
    name = "tpu_program",
    srcs = ["tpu_program.cc"],
    hdrs = ["tpu_program.h"],
    deps = [
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_support",
        ":tpu_compile_proto_cc",
        ":tpu_executable_info_proto_cc",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:xla_proto_cc",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:hlo_module_group",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/stream_executor/tpu:proto_helper",
        "//tensorflow/stream_executor/tpu:status_helper",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "@com_google_absl//absl/types:optional",
    ],
)

cc_library(
    name = "tpu_compilation_cache_external",
    srcs = ["tpu_compilation_cache_external.cc"],
    hdrs = [
        "tpu_compilation_cache_external.h",
    ],
    deps = [
        ":tpu_compilation_cache_entry",
        ":tpu_compilation_cache_key",
        ":tpu_compilation_cache_metrics",  # buildcleaner: keep
        ":tpu_compilation_cache_metrics_hdrs",
        ":tpu_compilation_cache_proto_cc",
        ":tpu_compile_c_api_hdrs",
        ":tpu_compile_op_support",
        ":tpu_mesh_state_interface",
        ":tpu_program",
        ":tpu_util",
        ":trace_util_hdrs",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:hlo_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:refcount",
        "//tensorflow/core/profiler/lib:traceme",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "tpu_compilation_cache_metrics_hdrs",
    hdrs = ["tpu_compilation_cache_metrics.h"],
    deps = [
        "//tensorflow/core/platform:types",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tpu_compilation_cache_metrics",
    srcs = ["tpu_compilation_cache_metrics.cc"],
    deps = [
        ":tpu_compilation_cache_metrics_hdrs",
    ],
)

cc_library(
    name = "trace_util_hdrs",
    srcs = [],
    hdrs = ["trace_util.h"],
    deps = [
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tpu_util_hdrs",
    srcs = [],
    hdrs = ["tpu_util.h"],
    deps = [
        ":tpu_compilation_cache_key",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tpu_util",
    srcs = ["tpu_util.cc"],
    hdrs = ["tpu_util.h"],
    deps = [
        ":tpu_compilation_cache_key",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

tf_proto_library_cc(
    name = "tpu_compilation_cache_proto",
    srcs = ["tpu_compilation_cache.proto"],
    cc_api_version = 2,
)
tensorflow/core/tpu/kernels/tpu_compilation_cache.proto (new file, 25 lines)
@@ -0,0 +1,25 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";

package tensorflow.tpu;

// Target type for compilation cache fetch operation.
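// MAIN denotes the main computation; SHARDING and UNSHARDING are assumed
// here to denote the per-core variable sharding and unsharding programs.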
enum CompilationCacheFetchTarget {
  INVALID = 0;
  MAIN = 1;
  SHARDING = 2;
  UNSHARDING = 3;
}
tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h (new file, 84 lines)
@@ -0,0 +1,84 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_

#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"

namespace tensorflow {
namespace tpu {

class CompilationCacheEntry {
 public:
  explicit CompilationCacheEntry(
      std::unique_ptr<const TpuProgram> tpu_program)
      : tpu_program_(std::move(tpu_program)) {}

  // Constructor for an empty entry.
  CompilationCacheEntry() : tpu_program_(nullptr) {}

  const TPUExecutableInfoProto* get_executable_info() const {
    return &tpu_program_->executable_info();
  }

  const TPUHostTransferInfoProto* get_host_transfer_info() const {
    return &tpu_program_->host_transfer_info();
  }

  const xla::HloProto* get_hlo_metadata() const {
    return &tpu_program_->hlo_metadata();
  }

  // TODO(henrytan,jiawenhao): When should we expect more than one
  // XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then.
  const XLA_TpuProgram* get_tpu_program() const {
    CHECK_EQ(tpu_program_->program_count(), 1);
    return tpu_program_->tpu_programs()[0];
  }

 private:
  std::unique_ptr<const TpuProgram> tpu_program_;
};

// Base class for a reference to a cached proto. A unique_ptr to a
// CompilationCacheEntryRef is returned by all the cache Lookup methods below,
// and ensures the underlying proto is not garbage-collected until the client
// discards the ptr.
class CompilationCacheEntryRef {
 public:
  virtual ~CompilationCacheEntryRef() = default;

  // Returns a CompilationCacheEntry that should not be used beyond the
  // lifetime of the CompilationCacheEntryRef.
  virtual CompilationCacheEntry get() = 0;
};

// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a
// concrete ref holder object.
class CompilationRefHolder : public ResourceBase {
 public:
  ~CompilationRefHolder() override = default;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_
tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc (new file, 791 lines)
@@ -0,0 +1,791 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"

namespace tensorflow {
namespace tpu {

namespace {

using CompilationEntry = TpuCompilationCacheInterface::CompilationEntry;

int64 get_uid() {
  uint64 unsigned_rand = random::New64() & INT64_MAX;
  return static_cast<int64>(unsigned_rand);
}

void PopulateEntry(const std::string& key, CompilationEntry* entry,
                   std::unique_ptr<TpuProgram> tpu_program) {
  // Make the unique keys for each cached proto.
  for (int i = 0; i < tpu_program->program_count(); ++i) {
    entry->proto_key.push_back(ProtoKeyForComputation(key, i));
  }

  entry->tpu_program = std::move(tpu_program);
  entry->initialized = true;
}

std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
  if (!key.has_guaranteed_const) {
    return key.prefix;
  }
  return absl::StrCat(key.prefix, "|", key.session_handle, "|",
                      key.guaranteed_const_fingerprint());
}
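
// With guaranteed constants the key has the form
// "<prefix>|<session_handle>|<fingerprint>"; otherwise it is just the prefix.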
|
||||
// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor
|
||||
// data to compute the fingerprint.
|
||||
std::string GuaranteedConstFingerprint(
|
||||
const string& fingerprint_in_metadata,
|
||||
const OpInputList& guaranteed_constants) {
|
||||
if (fingerprint_in_metadata.empty()) {
|
||||
uint64_t fingerprint = 0;
|
||||
for (const auto& constant : guaranteed_constants) {
|
||||
fingerprint = TpuCompile_CreateGuaranteedConstFingerprint(
|
||||
fingerprint, constant.tensor_data().data(),
|
||||
constant.tensor_data().size());
|
||||
}
|
||||
return std::to_string(fingerprint);
|
||||
} else {
|
||||
return fingerprint_in_metadata;
|
||||
}
|
||||
}

std::string CreateShapePrefix(
    const std::vector<tensorflow::TensorShape>& dynamic_shapes) {
  std::string shapes_prefix;
  for (const TensorShape& shape : dynamic_shapes) {
    for (int64 size : shape.dim_sizes()) {
      absl::StrAppend(&shapes_prefix, size, ",");
    }
    absl::StrAppend(&shapes_prefix, ";");
  }
  return shapes_prefix;
}
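
// For illustration: dynamic_shapes = {[2,3], [4]} yields the prefix
// "2,3,;4,;" -- dimension sizes separated by ',' and shapes by ';'.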

// Include compilation configurations of the arguments that are not captured
// by the called graph.
std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) {
  std::string config_prefix;
  for (const auto& arg : metadata.args()) {
    if (arg.is_same_data_across_replicas()) {
      absl::StrAppend(&config_prefix, ":s");  // Same.
    } else {
      absl::StrAppend(&config_prefix, ":");  // Different.
    }
    if (arg.enable_xla_sharding() ==
        tpu::TPUCompileMetadataProto::Arg::ALLOWED) {
      absl::StrAppend(&config_prefix, "e");  // Enabled.
    }
    if (arg.unrestricted_layout()) {
      absl::StrAppend(&config_prefix, ":u");  // Unrestricted.
    }
    absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")");
    if (arg.has_shape()) {
      absl::StrAppend(&config_prefix, ",shape(");
      for (const auto& dim : arg.shape().dim()) {
        absl::StrAppend(&config_prefix, dim.size(), ",");
      }
      absl::StrAppend(&config_prefix, ")");
    }
  }
  return config_prefix;
}
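
// For illustration: one argument that is the same across replicas, with XLA
// sharding allowed, dtype DT_FLOAT (which prints via its enum value, 1), and
// shape [128] appends ":se,type(1),shape(128,)" to the prefix.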

}  // namespace

TpuCompilationCacheInterface::TpuCompilationCacheInterface(
    int64_t max_cache_size)
    : max_cache_size_(max_cache_size) {
  if (max_cache_size < 0) {
    LOG(FATAL) << "`max_cache_size` value must be greater than or equal to 0";
  }
  VLOG(1) << "Created compilation cache of size " << max_cache_size_
          << " bytes.";
}

TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      TF_CHECK_OK(Release(entry.first));
    }
  }
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted all reference holders should have already
  // been deleted, since they were holding references to the cache. So all
  // entries should be gone at this point.
  CHECK_EQ(cache_store_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}

std::string TpuCompilationCacheInterface::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) const {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key "
          << ConstructCompilationCacheKey(subgraph_key);
  return "";
}

void TpuCompilationCacheInterface::InsertEntry(
    const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key,
    CompilationEntry* entry) {
  entry->parent = this;
  entry->subgraph_key = cache_key;
  entry->uid = get_uid();
  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
  entry->cache_entry_debug_string = subgraph_key.prefix;
  VLOG(1) << "Cache Initializing Entry Session Debug "
          << entry->cache_entry_debug_string;

  if (!subgraph_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle),
      cache_key));
  fingerprint_key_map_.insert(std::make_pair(
      strings::StrCat(subgraph_key.prefix,
                      subgraph_key.guaranteed_const_fingerprint()),
      cache_key));
}

CompilationEntry* TpuCompilationCacheInterface::InitializeEntry(
    const string& key,
    const std::function<Status(TpuProgram*)>& initialize_program,
    const TpuCompilationCacheKey& subgraph_key) {
  CompilationEntry* main_entry = new CompilationEntry();

  // Add the entry to the cache, with size zero since there are no compiled
  // programs in it. Once the subgraph has been compiled,
  // UpdateEntryAfterCompilation will be called to potentially mark old entries
  // that don't fit any more for eviction.
  //
  // At this point there is one reference to entry, which is owned by the
  // caller who created the entry. A second reference, owned by the cache, will
  // be added below since we leave the entry in the 'marked for eviction' state
  // here.
  InsertEntry(key, subgraph_key, main_entry);

  // Initialize the programs outside the lock so that other cache operations
  // can proceed during the (potentially lengthy) initialization.
  Status initialization_status;

  auto tpu_program = absl::make_unique<TpuProgram>();
  {
    mu_.Unlock();
    {
      profiler::TraceMe compile_programs_traceme(
          "TPU compilation cache compile",
          /*level=*/2);
      initialization_status = initialize_program(tpu_program.get());
    }
    mu_.Lock();
  }

  main_entry->initialization_status = initialization_status;

  // Add the entry to the uid index.
  auto uid_inserted = entries_by_uid_.insert(
      std::pair<int64, CompilationEntry*>(main_entry->uid, main_entry));
  CHECK(uid_inserted.second);

  if (initialization_status.ok()) {
    // Compute the entry's total size once all members are initialized.
    main_entry->total_size = tpu_program->program_size();
  }

  // TODO(henrytan): handle sharding/unsharding.
  PopulateEntry(key, main_entry, std::move(tpu_program));

  for (int64 i = 0; i < main_entry->proto_key.size(); ++i) {
    auto entry_inserted = entries_by_proto_key_.insert(
        std::pair<string, std::pair<CompilationEntry*, int>>(
            main_entry->proto_key[i], std::make_pair(main_entry, i)));
    CHECK(entry_inserted.second);
  }

  // Add the size to marked_for_eviction_size_ since it will be adjusted down
  // again when the newly-created entry gets unmarked.
  marked_for_eviction_size_ += main_entry->total_size;
  return main_entry;
}

/*static*/ TpuCompilationCacheKey
TpuCompilationCacheInterface::CreateCompilationCacheKey(
    absl::string_view function_name, uint64 function_library_fingerprint,
    absl::string_view mlir_module,
    const tensorflow::OpInputList& guaranteed_constants,
    const std::vector<tensorflow::TensorShape>& dynamic_shapes,
    const tensorflow::tpu::TPUCompileMetadataProto& metadata,
    const TpuMeshStateInterface& mesh_state) {
  VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint;
  std::string shapes_prefix = CreateShapePrefix(dynamic_shapes);
  VLOG(1) << "shapes_prefix = " << shapes_prefix;
  std::string config_prefix = CreateConfigPrefix(metadata);
  VLOG(1) << "config_prefix = " << config_prefix;
  std::vector<int32_t> flattened_device_ids;
  if (metadata.has_device_assignment()) {
    for (const auto& device :
         metadata.device_assignment().computation_devices()) {
      flattened_device_ids.insert(flattened_device_ids.end(),
                                  device.replica_device_ids().begin(),
                                  device.replica_device_ids().end());
    }
  }
  // TODO(henrytan): return the debug_string.
  const char* prefix =
      TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty{
          config_prefix.data(),
          shapes_prefix.data(),
          function_name.data(),
          mlir_module.data(),
          flattened_device_ids.data(),
          flattened_device_ids.size(),
          guaranteed_constants.size(),
          function_library_fingerprint,
          metadata.num_cores_per_replica(),
          metadata.num_replicas(),
          mesh_state.data(),
      });
  auto buffer_cleanup = gtl::MakeCleanup([prefix]() { delete[] prefix; });
  TpuCompilationCacheKey key;
  key.prefix = prefix;

  // Guaranteed constants can be different across sessions. Use session_handle
  // and guaranteed_const fingerprint to guarantee no collision.
  if (guaranteed_constants.size() > 0) {
    key.has_guaranteed_const = true;
    key.session_handle = metadata.session_handle();
    // Both `metadata` and `guaranteed_constants` are captured by reference,
    // based on the assumption that their lifetimes are managed through the
    // `TPUCompileOpKernelImpl`, which outlives the compilation cache lookups.
    string fingerprint;
    key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants,
                                        fingerprint]() mutable {
      if (fingerprint.empty()) {
        fingerprint = GuaranteedConstFingerprint(
            metadata.guaranteed_const_fingerprint(), guaranteed_constants);
      }
      return fingerprint;
    };
  }
  return key;
}

TpuCompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
  return new RefHolder(this);
}

Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompilationEntry* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // If already evicted, return ok.
      return Status::OK();
    }

    // Mark entry for eviction.
    CompilationEntry* subgraph_to_evict = iter->second;
    // If there are external references, this API should not be used.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCache::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
            << " for eviction";
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict if the refcount is exactly one; otherwise only discard the
    // cache's reference to the entry, and the actual eviction will happen
    // when the refholder's references go away.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from the device cache if the entry is evicted from the host cache.
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

Status TpuCompilationCacheInterface::Release(int64 subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompilationEntry* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  UnloadAndDestroy(deleted_entry);
  return Status::OK();
}

void TpuCompilationCacheInterface::UnloadAndDestroy(CompilationEntry* entry) {
  if (!entry) return;

  CHECK(entry->RefCountIsOne());
  entry->tpu_program->UnloadAndDestroyPrograms();
  entry->Unref();
}

size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) {
  auto erased = cache_store_.erase(key);
  TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size());
  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      parsed_key_or_status.ConsumeValueOrDie();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}

ABSL_MUST_USE_RESULT CompilationEntry*
TpuCompilationCacheInterface::DiscardEntryRef(CompilationEntry* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    for (const string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}

void TpuCompilationCacheInterface::DiscardEntryRefs(
    gtl::ArraySlice<CompilationEntry*> entries) {
  std::vector<CompilationEntry*> removed_entries;
  {
    absl::MutexLock lock(&mu_);

    for (auto entry : entries) {
      removed_entries.push_back(DiscardEntryRef(entry));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  for (auto removed_entry : removed_entries) {
    UnloadAndDestroy(removed_entry);
  }
}

ABSL_MUST_USE_RESULT CompilationEntry*
TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
  CompilationEntry* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}

void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
    CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size,
  // and mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark
  // the most-recently used entry (i.e., do nothing if
  // entries_by_last_use_.size() == 1, which means there is only one entry not
  // already marked for eviction), so that an entry persists in the cache even
  // if it is larger than the allocated cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}

Status TpuCompilationCacheInterface::ToSubEntryRef(
    CompilationCacheEntryRef* entry,
    CompilationCacheFetchTarget fetch_target) const {
  return static_cast<EntryRefImpl*>(entry)->ToSubEntryRef(fetch_target);
}

TpuCompilationCacheInterface::EntryRefImpl::EntryRefImpl(
    TpuCompilationCacheInterface* parent, CompilationEntry* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only
    // refcount the main entry.
    entry_->main_entry->Ref();
  }
}

TpuCompilationCacheInterface::EntryRefImpl::~EntryRefImpl() {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}

CompilationCacheEntry TpuCompilationCacheInterface::EntryRefImpl::get() {
  if (entry_ == nullptr) {
    // Create an empty entry if the entry is nullptr. This corresponds to
    // non-existing sharding/unsharding entries.
    return CompilationCacheEntry();
  }
  return CompilationCacheEntry(std::move(entry_->tpu_program));
}

Status TpuCompilationCacheInterface::EntryRefImpl::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompilationEntry* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // The cache entry does not have the requested subentry. Unref and replace
    // with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always on the main entry, we don't need
  // to ref/unref.
  entry_ = target;
  return Status::OK();
}

Status TpuCompilationCacheInterface::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompilationEntry* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program->program_size()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = std::unique_ptr<CompilationCacheEntryRef>(
      new EntryRefImpl(this, cache_entry, proto_index));
  return Status::OK();
}

Status TpuCompilationCacheInterface::Lookup(
    const string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompilationEntry* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = std::unique_ptr<CompilationCacheEntryRef>(
      new EntryRefImpl(this, cache_entry, proto_index));
  return Status::OK();
}

Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<CompilationEntry*>* removed_entries,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<Status(TpuProgram*)>& compile_function) {
  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock
  // for the lifetime of the object; see the InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_store_.find(cache_key);
  bool is_new_key = iter == cache_store_.end();

  const std::string session_name = SessionNameFromMetadata(session_metadata);

  CompilationEntry* entry = nullptr;
  if (is_new_key) {
    cache_key = ConstructCompilationCacheKey(subgraph_key);
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");

    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if the caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!IsTpuCompilationEnabled()) {
      const string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) {
          // TODO(henrytan): add DebugKey as cache_entry_debug_string to
          // TpuCompilationCacheKey.
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    TRACELITERAL("TPU host compilation cache: compilation done.");

    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation done for cache_key(",
        cache_key, "), session_name(", session_name, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly to the session.
    if (!session_name.empty()) {
      entry->tpu_program->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name);
    const string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_store_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << " marked for eviction "
            << (cache_store_.size() - entries_by_last_use_.size())
            << " entries (" << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompilationEntry* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *may_modify_variables = entry->tpu_program->may_modify_variables();
  *hlo_metadata = entry->hlo_metadata;

  // If the caller didn't supply a per_step_ref_holder then the caller is
  // going to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder
    // that will discard the reference when the step completes.
    RefHolder* cast_ref_holder = static_cast<RefHolder*>(per_step_ref_holder);
    TF_RET_CHECK(cast_ref_holder != nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table, indicating that this entry is the most recently
  // used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << " cache is " << cache_store_.size()
              << " entries (" << cache_size_ + marked_for_eviction_size_
              << " bytes), "
              << " marked for eviction "
              << (cache_store_.size() - entries_by_last_use_.size())
              << " entries (" << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}

tensorflow::Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& cache_key,
    const tensorflow::SessionMetadata* session_metadata,
    TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
    std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
    std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
    const std::function<tensorflow::Status(TpuProgram*)>& compile_function) {
  std::vector<CompilationEntry*> removed_entries;
  auto status = CompileIfKeyAbsentHelper(
      cache_key, session_metadata, per_step_ref_holder, uid, proto_key,
      may_modify_variables, &removed_entries, hlo_metadata, compile_function);
  for (auto entry : removed_entries) {
    UnloadAndDestroy(entry);
  }
  return status;
}

}  // namespace tpu
}  // namespace tensorflow
394
tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h
Normal file
@ -0,0 +1,394 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_

#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_program.h"

namespace tensorflow {
namespace tpu {

const char kCompilationCacheResourceName[] = "tpu_compilation_cache";
const char kCompilationCacheUnloaderResourceName[] =
    "tpu_compilation_cache_unloader";

// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCache::MakePerStepRefHolder to create an instance of a
// concrete ref holder object.
class TpuCompilationRefHolder : public ResourceBase {
 public:
  ~TpuCompilationRefHolder() override = default;
};

class TpuCompilationCacheInterface : public ResourceBase {
 public:
  using Status = ::stream_executor::port::Status;

  // An entry in the compilation cache. The entry is deleted once it has been
  // marked for eviction from the cache _and_ all steps that use it have
  // completed. When the entry is first created, it is uninitialized and a
  // client-supplied compilation function is run outside the cache's lock to
  // generate the programs to be stored in the entry. Any other client that
  // requests the entry will block until it has been initialized. Each entry
  // has a last_use value that is set from a monotonically-increasing counter
  // in the cache whenever the entry is referenced. When the cache becomes
  // full, entries are marked for eviction in LRU order.
  //
  // The bridge can request XLA to generate separate sharding and unsharding
  // programs along with the main program; we use the nested fields
  // sharding_entry and unsharding_entry to store them under the main entry,
  // and these two fields must be either both present or both absent. They
  // have a back pointer main_entry to refer to the main program. These nested
  // entries share the same cache key and the same lifetime as the main entry,
  // so we use the refcount on the main entry to track the access to any of
  // them.
  //
  //              /-------------------------------\
  //              v                                \
  //   main_entry (refcount) -> sharding_entry -> main_entry
  //      ^       \
  //      |        \-> unsharding_entry -> main_entry
  //      \--------------------------------------/
  struct CompilationEntry : public core::RefCounted {
    TpuCompilationCacheInterface* parent = nullptr;  // Not owned.
    bool initialized = false;

    // The Status returned by the compilation function when the entry is
    // initialized. This status will be returned to any client that requests
    // the entry.
    Status initialization_status;

    // The uid describing this entry.
    int64 uid;
    std::vector<string> proto_key;

    // Counter to keep track of LRU entries for the eviction policy.
    int64 last_use = -1;

    // The unique key describing this entry.
    std::string subgraph_key;

    // Entries representing the associated sharding and unsharding programs,
    // which share the same lifetime as the owning main entry, so we always
    // use the main entry's ref count.
    std::unique_ptr<CompilationEntry> sharding_entry;
    std::unique_ptr<CompilationEntry> unsharding_entry;

    // The number of 'external' client-held references to the entry.
    int external_references = 0;

    std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;

    // The sum of the SpaceUsed of each of the elements of programs; an
    // estimate of how much RAM the entry consumes, used to determine when
    // entries must be marked for eviction.
    int64 total_size = 0;

    // Only used for the nested sharding/unsharding entries to point to the
    // owning main entry.
    CompilationEntry* main_entry = nullptr;

    // Debug info in case we miss.
    string cache_entry_debug_string;

    // Compiled TPU program.
    std::unique_ptr<TpuProgram> tpu_program;
  };

  explicit TpuCompilationCacheInterface(int64_t max_cache_size);
  ~TpuCompilationCacheInterface() override;
  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
  TpuCompilationCacheInterface& operator=(
      const TpuCompilationCacheInterface&) = delete;

  Status CompileIfKeyAbsent(
      const TpuCompilationCacheKey& cache_key,
      const SessionMetadata* session_metadata,
      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
      const std::function<tensorflow::Status(TpuProgram*)>& compile_function);

  static TpuCompilationCacheKey CreateCompilationCacheKey(
      absl::string_view function_name, uint64 function_library_fingerprint,
      absl::string_view mlir_module,
      const tensorflow::OpInputList& guaranteed_constants,
      const std::vector<tensorflow::TensorShape>& dynamic_shapes,
      const tensorflow::tpu::TPUCompileMetadataProto& metadata,
      const TpuMeshStateInterface& mesh_state);

  string DebugString() const override {
    return "TpuCompilationCacheInterface";
  }

  // Makes a reference holder for this cache that can be stored in the
  // per-step resource manager and will ensure that compiled entries persist
  // until the end of a step.
  TpuCompilationRefHolder* MakePerStepRefHolder();

  // Differences between MarkEntryForEviction and Release:
  // There are two modes of managing cache entries:
  // 1) LRU eviction + pinning; 2) manual.
  // We use mode 1) if a TpuCompilationRefHolder is provided to
  // CompileIfKeyAbsent. Otherwise it is manual mode (mainly used by XRT).
  // MarkEntryForEviction should only be used in mode 1) to eagerly evict
  // cache entries when callers know that they do not need them anymore.
  // Release should only be used in mode 2) to explicitly remove an entry.

  // Mark the entry indexed by `subgraph_uid` for eviction. This should only
  // be called if per_step_ref_holder was NOT nullptr in the corresponding
  // call to CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use
  // Release(int64 subgraph_uid).
  Status MarkEntryForEviction(int64 subgraph_uid);

  // Manually discards a reference to the compiled subgraph. This should only
  // be called if per_step_ref_holder was nullptr in the corresponding call to
  // CompileIfKeyAbsent(subgraph_key, ...).
  Status Release(int64 subgraph_uid);
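
  // Illustrative sketch (not part of the API above): in manual mode, i.e.
  // when no per-step ref holder is supplied, a caller owns one external
  // reference per CompileIfKeyAbsent call and must balance it with Release.
  // `cache`, `key`, and `compile_fn` are assumed to exist:
  //
  //   int64 uid;
  //   std::vector<string> proto_keys;
  //   std::vector<bool> may_modify_variables;
  //   std::vector<std::shared_ptr<const xla::HloProto>> hlo_metadata;
  //   TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
  //       key, /*session_metadata=*/nullptr, /*per_step_ref_holder=*/nullptr,
  //       &uid, &proto_keys, &may_modify_variables, &hlo_metadata,
  //       compile_fn));
  //   ... execute the compiled programs ...
  //   TF_RETURN_IF_ERROR(cache->Release(uid));  // Drop the manual reference.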

  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by key. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  Status Lookup(const string& proto_key,
                std::unique_ptr<CompilationCacheEntryRef>* entry);

  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by uid. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  Status Lookup(int64 uid, int proto_index,
                std::unique_ptr<CompilationCacheEntryRef>* entry);

  // Mutates the main entry ref to point to the entry's subentry
  // (for sharding/unsharding) or main entry (unchanged) representing the
  // fetch target. The entry ref needs to point to the main entry before this
  // call.
  //
  // If the requested subentry does not exist, the ref will point to a nullptr
  // entry.
  Status ToSubEntryRef(CompilationCacheEntryRef* entry,
                       CompilationCacheFetchTarget fetch_target) const;

 private:
  // Wrapper for a cache entry that holds a reference to the entry until the
  // wrapper is deleted. This wrapper is the concrete type of
  // CompilationCacheEntryRef returned by Lookup.
  class EntryRefImpl : public CompilationCacheEntryRef {
   public:
    EntryRefImpl(TpuCompilationCacheInterface* parent, CompilationEntry* entry,
                 int index);
    ~EntryRefImpl() override;

    CompilationCacheEntry get() override;

    // Mutates this ref to point to the entry's subentry (for
    // sharding/unsharding) or main entry (unchanged) as specified by
    // fetch_target. The refcount is kept unchanged, since we only track the
    // refcount of the main entry. The entry ref needs to point to the main
    // entry before this call.
    //
    // If the requested subentry does not exist, the ref will point to a
    // nullptr entry, and the original entry will be unref'ed.
    Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);

   private:
    TpuCompilationCacheInterface* parent_;  // Not owned.
    // A reference to entry_ is acquired in the constructor and released via
    // parent->DiscardEntryRefs in the destructor.
    CompilationEntry* entry_;
    // The program in entry_ that is returned by the get method.
    int index_;
  };

  // Private implementation of the generic CompilationRefHolder that knows
  // about CompiledSubgraph entries.
  class RefHolder : public TpuCompilationRefHolder {
   public:
    explicit RefHolder(TpuCompilationCacheInterface* parent)
        : parent_(parent) {
      parent_->Ref();
    }
    ~RefHolder() override {
      // Release the references to the entries added to this holder, then our
      // reference to the parent.
      parent_->DiscardEntryRefs(entries_);
      parent_->Unref();
    }

    // Adds entry to the list of entries that will be released when the
    // RefHolder is destroyed. Each entry is released via a call to
    // parent_->DiscardEntryRefs.
    void AddRef(CompilationEntry* entry) { entries_.push_back(entry); }

    string DebugString() const override {
      return "TpuCompilationCacheInterface::RefHolder";
    }

   private:
    TpuCompilationCacheInterface* parent_;  // Not owned.
    std::vector<CompilationEntry*> entries_;
  };
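
  // Illustrative sketch (assumed usage, not part of this change): in per-step
  // mode a holder created by MakePerStepRefHolder() is stored in the step's
  // resource manager, passed as per_step_ref_holder to CompileIfKeyAbsent(),
  // and deleted by the resource manager when the step finishes, at which
  // point all entry references it accumulated via AddRef() are discarded.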

  // The bulk of the implementation of CompileIfKeyAbsent(), with the
  // exception of unloading programs that correspond to possibly removed cache
  // entries. The split helps to manage locking since we prefer to perform
  // unloading without holding extra locks.
  Status CompileIfKeyAbsentHelper(
      const TpuCompilationCacheKey& subgraph_key,
      const SessionMetadata* session_metadata,
      TpuCompilationRefHolder* per_step_ref_holder, int64* uid,
      std::vector<string>* proto_key, std::vector<bool>* may_modify_variables,
      std::vector<CompilationEntry*>* removed_entries,
      std::vector<std::shared_ptr<const xla::HloProto>>* hlo_metadata,
      const std::function<Status(TpuProgram*)>& compile_function);

  // This is called by the cache when an entry is marked for eviction; by a
  // RefHolder (via DiscardEntryRefs) when a step completes; and by an
  // EntryRefImpl when it is destroyed. Releases one reference to entry if
  // more than one remains. If only one reference is left, the entry is
  // removed from cache_ and is returned to the caller, which must eventually
  // call UnloadAndDestroy(). We do not call UnloadAndDestroy within
  // DiscardEntryRef to avoid holding the lock during program unloading.
  ABSL_MUST_USE_RESULT CompilationEntry* DiscardEntryRef(
      CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Convenience method called by ~RefHolder without mu_ held. Calls
  // DiscardEntryRef on every element of entries.
  void DiscardEntryRefs(gtl::ArraySlice<CompilationEntry*> entries);

  // Marks the oldest unmarked entry for eviction. Requires that there is at
  // least one such entry. In case the evicted entry had only one reference it
  // is removed from the cache and returned to the caller, which must
  // eventually call UnloadAndDestroy.
  CompilationEntry* MarkOldestEntryForEviction()
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Updates datastructures to indicate that entry, which had been marked for
  // eviction, has been looked up. This is called by CompileIfKeyAbsent when
  // an entry is newly created, or when an entry that has been marked for
  // eviction but not yet evicted is looked up.
  //
  // First the entry is unmarked for eviction, i.e. the cache gains a
  // reference to entry, entry's last_use field is set to be the most recent
  // value of use_counter_ and entries_by_last_use_ is updated accordingly.
  //
  // Next, the size of the cache is examined to see if any other entries need
  // to be marked for eviction now that entry has been unmarked. While the
  // total size of unmarked cached entries is greater than max_cache_size_,
  // entries are marked for eviction in LRU order. The most recently used
  // entry is never marked for eviction, so an entry larger than the max cache
  // size will remain in the cache until it is replaced by something else. In
  // case some entries actually were removed from the cache, they are returned
  // to the caller via removed_entries. The caller must eventually delete them
  // by calling UnloadAndDestroy.
  void LookupEntryMarkedForEviction(
      CompilationEntry* entry, std::vector<CompilationEntry*>* removed_entries)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Removes the entry with the given key from the cache.
  size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Inserts the given key and entry into the cache.
  void InsertEntry(const std::string& key,
                   const TpuCompilationCacheKey& subgraph_key,
                   CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the cache key matching the given subgraph_key.
  std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a new entry by running initialize_programs and places it in the
  // cache to be looked up by key. The new entry is in the 'marked for
  // eviction' state (not present in entries_by_last_use_) and the caller is
  // expected to call LookupEntryMarkedForEviction after InitializeEntry.
  //
  // **InitializeEntry releases mu_ during the call to initialize_programs.**
  CompilationEntry* InitializeEntry(
      const string& key,
      const std::function<Status(TpuProgram*)>& initialize_program,
      const TpuCompilationCacheKey& subgraph_key)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Unloads the program associated with the entry from all local devices and
  // deletes the entry itself. It is assumed that no one else has a reference
  // to it and that all related keys had already been removed from the cache.
  // The call can perform device IO, so no locks should be held while calling
  // it.
  void UnloadAndDestroy(CompilationEntry* entry) ABSL_LOCKS_EXCLUDED(mu_);

  // The maximum size of entries that are stored in the cache before entries
  // are marked for eviction.
  const int64 max_cache_size_;

  mutable absl::Mutex mu_;
  // The total size of entries that are stored and not marked for eviction.
  int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0;

  // The total size of entries that are marked for eviction.
  int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;

  // The value to assign to the last_use field of the next entry that is
  // looked up.
  int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0;

  // session_key_map_ and fingerprint_key_map_ are used for looking up the
  // cache_ key matching a given subgraph key. When doing a lookup, check
  // session_key_map_ first to avoid unnecessary fingerprint computation.
  // Map from key prefix + session_handle to a cache_ key.
  std::unordered_map<string, string> session_key_map_ ABSL_GUARDED_BY(mu_);

  // Map from key prefix + fingerprint to a cache_ key.
  std::unordered_map<string, string> fingerprint_key_map_ ABSL_GUARDED_BY(mu_);

  // All the subgraph entries that can be looked up in the cache. An entry is
  // marked for eviction iff it is present in cache_ and not in
  // entries_by_last_use_.
  std::unordered_map<string, CompilationEntry*> cache_store_
      ABSL_GUARDED_BY(mu_);

  // All the subgraph entries that can be looked up in the cache, indexed by
  // uid.
  absl::node_hash_map<int64, CompilationEntry*> entries_by_uid_
      ABSL_GUARDED_BY(mu_);

  // All the protos that can be looked up in the cache, indexed by proto key.
  // The value of the map is a subgraph and the index of the proto compiled
  // for that subgraph.
  std::unordered_map<string, std::pair<CompilationEntry*, int>>
      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);

  // Map from last_use to entry, used to mark entries for eviction in LRU
  // order. If an entry's last_use counter is not present as a key in
  // entries_by_last_use_ then the entry has been marked for eviction.
  std::map<int64, CompilationEntry*> entries_by_last_use_ ABSL_GUARDED_BY(mu_);
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_
53
tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_

#include <functional>
#include <string>

#include "absl/types/optional.h"

namespace tensorflow {
namespace tpu {

struct TpuCompilationCacheKey {
  // Prefix of the key.
  std::string prefix;

  // A boolean flag to specify if `guaranteed_const` is used. Guaranteed
  // constants are normally used in TPU inference to avoid re-copying
  // unchanged variables onto the TPU device. They promise that the value is
  // identical for every execution in the same session, even if the actual
  // value changes in later executions.
  bool has_guaranteed_const = false;

  // Unique session identifier. It is set when `has_guaranteed_const` is true.
  std::string session_handle;

  // Fingerprint of the `guaranteed_const` value. It is set when
  // `has_guaranteed_const` is true. Produces the value lazily, when
  // necessary.
  std::function<std::string()> guaranteed_const_fingerprint;

  // A more verbose key for debugging purposes.
  std::string debug_string;

  explicit TpuCompilationCacheKey() {}
  explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {}
};
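
// For illustration: the full cache key constructed from this struct (see
// ConstructCompilationCacheKey in tpu_compilation_cache_external.cc) is
// `prefix` alone when has_guaranteed_const is false, and
// "<prefix>|<session_handle>|<guaranteed_const_fingerprint()>" otherwise.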

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_
93
tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc
Normal file
@ -0,0 +1,93 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"

#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"

namespace tensorflow {
namespace tpu {

namespace {
class CompilationCacheFetchTargetUtility {
 public:
  CompilationCacheFetchTargetUtility()
      : names_({"Invalid", "Main", "Sharding", "Unsharding"}) {}

  std::string name(CompilationCacheFetchTarget target) const {
    return names_[static_cast<int>(target)];
  }

 private:
  const std::vector<std::string> names_;
};

std::string GetName(CompilationCacheFetchTarget target) {
  static const auto* util = new CompilationCacheFetchTargetUtility();
  return util->name(target);
}

}  // namespace

TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup(
    TpuCompilationCacheInterface* cache)
    : cache_(cache) {}

TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() {
  cache_->Unref();
}

Status TpuCompilationCacheLocalLookup::Lookup(
    const string& proto_key, std::unique_ptr<CompilationCacheEntryRef>* entry,
    CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup",
                                         /*level=*/2);
  Status s = cache_->Lookup(proto_key, entry);
  VLOG(1) << "Looked up key " << proto_key
          << " in local subgraph cache with status " << s;
  if (!s.ok()) {
    return s;
  }
  s = cache_->ToSubEntryRef(entry->get(), fetch_target);

  VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
          << s;
  return s;
}

Status TpuCompilationCacheLocalLookup::Lookup(
    int64 uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry,
    CompilationCacheFetchTarget fetch_target) {
  profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid",
                                         /*level=*/2);
  Status s = cache_->Lookup(uid, proto_index, entry);
  VLOG(1) << "Looked up uid " << uid << ", index " << proto_index
          << " in local subgraph cache with status " << s;
  if (!s.ok()) {
    return s;
  }
  s = cache_->ToSubEntryRef(entry->get(), fetch_target);
  VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status "
          << s;
  return s;
}

string TpuCompilationCacheLocalLookup::DebugString() const {
  return "TpuCompilationCacheLocalLookup";
}

}  // namespace tpu
}  // namespace tensorflow
99
tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h
Normal file
@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_
|
||||
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Base class allowing Execute Ops to look up ISA protos. Different subclasses
|
||||
// are used when the execute Op is in the same address space as the compile Op,
|
||||
// and when they need to communicate over RPC.
|
||||
class TpuCompilationCacheLookup : public ResourceBase {
|
||||
public:
|
||||
~TpuCompilationCacheLookup() override = default;
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by key. On success a wrapper for the proto is
|
||||
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||
// execution of the Op requesting the proto.
|
||||
//
|
||||
// Only one of the main, sharding, unsharding entries is fetched, as specified
|
||||
// in fetch_target.
|
||||
//
|
||||
// If the compilation does not create sharding/unsharding programs, but the
|
||||
// fetch_target requests one of them, then after this call
|
||||
// (*entry)->get().get_executable() will return nullptr.
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
return Lookup(proto_key, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
|
||||
// Looks up an executable corresponding to the model-parallel core index of
|
||||
// the subgraph represented by uid. On success a wrapper for the proto is
|
||||
// returned in program. The wrapper is guaranteed to be valid only during the
|
||||
// execution of the Op requesting the proto.
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) = 0;
|
||||
|
||||
virtual Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry) {
|
||||
return Lookup(uid, proto_index, std::move(entry),
|
||||
CompilationCacheFetchTarget::MAIN);
|
||||
}
|
||||
};
|
||||
|
||||
// Forward declaration to break cycle dependency graph.
|
||||
class TpuCompilationCacheInterface;
|
||||
|
||||
// Class for looking up ISA protos when the execute and compile Op are in the
|
||||
// same address space. The proto is simply looked up in the compilation cache,
|
||||
// without any serialization taking place.
|
||||
class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup {
|
||||
public:
|
||||
explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache);
|
||||
~TpuCompilationCacheLocalLookup() override;
|
||||
|
||||
Status Lookup(const string& proto_key,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
Status Lookup(int64 uid, int proto_index,
|
||||
std::unique_ptr<CompilationCacheEntryRef>* entry,
|
||||
CompilationCacheFetchTarget fetch_target) override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
private:
|
||||
// The subgraph compilation cache, in the same process address space where the
|
||||
// lookups are happening.
|
||||
TpuCompilationCacheInterface* cache_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_
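To make the Lookup contract above concrete, here is a minimal, hypothetical caller. It is not part of this commit: `lookup` and `proto_key` are illustrative stand-ins, and the accessor chain follows the comment on Lookup rather than a verified API.

  // Sketch: fetch the sharding program and fall back if it does not exist.
  std::unique_ptr<tensorflow::tpu::CompilationCacheEntryRef> entry;
  Status s = lookup->Lookup(
      proto_key, &entry, tensorflow::tpu::CompilationCacheFetchTarget::SHARDING);
  if (s.ok() && entry->get().get_executable() == nullptr) {
    // No sharding program was compiled; fall back to the MAIN entry, which is
    // what the two-argument Lookup overload defaults to.
    s = lookup->Lookup(proto_key, &entry);
  }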
32
tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc
Normal file
@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h"

namespace tensorflow {
namespace tpu {

/* static */
void TpuCompilationCacheMetrics::IncrementCacheLookupCount(
    bool is_cache_hit, absl::string_view session_name) {
  // A placeholder for tracking metrics.
}

/* static */
void TpuCompilationCacheMetrics::SetCacheEntryCount(int64 count) {
  // A placeholder for tracking metrics.
}

}  // namespace tpu
}  // namespace tensorflow
38
tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h
Normal file
@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tpu {

// Tracks TPU compilation cache metrics.
class TpuCompilationCacheMetrics {
 public:
  // Increments the cache lookup count, recording whether the lookup was a hit.
  static void IncrementCacheLookupCount(bool is_cache_hit,
                                        absl::string_view session_name);

  // Sets the total count of cache entries.
  static void SetCacheEntryCount(int64 count);
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_
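Both hooks are no-ops in this open-source commit. A runtime would typically call them from the cache's lookup path, along the lines of this illustrative sketch; the surrounding cache members (`LookupInternal`, `num_entries_`, `session_name`) are assumptions, not part of this file.

  // Sketch: record a lookup outcome and the resulting cache size.
  bool hit = LookupInternal(key, entry).ok();  // hypothetical helper
  tensorflow::tpu::TpuCompilationCacheMetrics::IncrementCacheLookupCount(
      hit, session_name);
  tensorflow::tpu::TpuCompilationCacheMetrics::SetCacheEntryCount(num_entries_);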
144
tensorflow/core/tpu/kernels/tpu_compile.proto
Normal file
@ -0,0 +1,144 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/tpu/compile_metadata.proto";
import "tensorflow/core/tpu/kernels/tpu_executable_info.proto";

message PerCoreVariableIndices {
  // For each resource variable output, the index of the corresponding input
  // and whether it was updated. The indices are sorted by input order.
  repeated TPUExecutableInfoProto.UpdateIndexPair variable_indices = 1;
}

message PerCoreArgShapes {
  // Argument shapes for each TPU core.
  repeated xla.ShapeProto shapes = 1;
}

message PerCoreOutputShapes {
  // Output shapes for each TPU core.
  repeated xla.ShapeProto shapes = 1;
}

message OutputDescriptionProto {
  // Type and shape of the output. The shape is the unflattened shape.
  // When `type` is DT_RESOURCE, `shape` is the shape of the resource
  // variable's value.
  tensorflow.DataType type = 1;
  tensorflow.TensorShapeProto shape = 2;

  // Constant output value, if known to be constant at JIT compilation time.
  // 'Tensor' is in host memory.
  bool is_constant = 3;
  tensorflow.TensorProto constant_value = 4;

  // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
  // the index of the input that contains the resource.
  int32 input_index = 5;

  // Whether this output is a TensorList.
  bool is_tensor_list = 6;
}

// Describes a variable write side effect of the computation.
message ResourceUpdateProto {
  // Index of the input that contains the variable resource to write to.
  int32 input_index = 1;

  // Type and shape of the tensor to be written back.
  // The `shape` field has the same meaning as the Argument::shape field.
  tensorflow.DataType type = 2;
  tensorflow.TensorShapeProto shape = 3;

  // Was the value of the variable modified by the computation?
  // (Always true, unless `return_updated_values_for_all_resources` is true.)
  bool modified = 4;

  // If the resource is a TensorArray, the set of gradients read or written.
  map<string, bool> tensor_array_gradients_accessed = 5;
}

// Describes the result of an XLA compiler compilation.
message XlaCompilationResultProto {
  // Vector that maps from the parameters of the XLA computation to their
  // original argument positions. To handle compile-time constant inputs, the
  // parameters to the XLA computation may be a subset of the original
  // arguments. The relative ordering of parameters is maintained.
  repeated int32 input_mappings = 1;

  // Input shapes of the computation. If we are flattening inputs, these are
  // the flattened shapes.
  repeated xla.ShapeProto xla_input_shapes = 2;

  // Output shape in XLA format. The output shape is always a tuple. If we
  // are flattening outputs, these are the flattened shapes.
  xla.ShapeProto xla_output_shape = 3;

  // TensorFlow shapes of outputs, together with the values of any
  // constant arguments. Vector indexed by Tensorflow _Retval number,
  // containing both constant and non-constant results.
  repeated OutputDescriptionProto outputs = 4;

  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
  // matching RecvAtHost/SendFromHost Ops in the outer graph.
  tf2xla.HostComputeMetadata host_compute_metadata = 5;

  // Resources whose values were updated by the computation, ordered
  // by return value position (which is the same as the order the resources
  // were passed as arguments). Resource updates follow the non-constant
  // results in the outputs of the XLA computation.
  repeated ResourceUpdateProto resource_updates = 6;

  // The XLA computation built from the tensorflow subgraph.
  xla.HloModuleProto computation = 7;
}

// TpuAotCompilationRequestProto represents a compilation request for
// performing ahead-of-time (AOT) compilation of XLA computations into XLA
// HLO IR.
message TpuAotCompilationRequestProto {
  // A set of HLO modules built to run concurrently
  // across different devices.
  xla.HloModuleGroupProto hlo_module_group = 1;

  // Compilation metadata.
  TPUCompileMetadataProto metadata = 2;

  // DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
  // which represents the device ids assigned to a set of replicated
  // computations. See the xla::DeviceAssignment class comment for more
  // details.
  xla.DeviceAssignmentProto device_assignment = 3;

  // Per-TPU-core program argument shapes.
  repeated PerCoreArgShapes per_core_arg_shapes = 4;

  // Per-TPU-core program output shapes.
  repeated PerCoreOutputShapes per_core_output_shapes = 5;

  // Per-TPU-core information giving the index of the corresponding input for
  // each resource variable output and whether it was updated. The indices are
  // sorted by input order.
  repeated PerCoreVariableIndices per_core_variable_indices = 6;

  // XLA compiler compilation result.
  XlaCompilationResultProto compilation_result = 7;
}
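As a reading aid, the sketch below walks a serialized request and prints the per-core argument shapes. The `request_bytes` variable and the logging context are assumptions for illustration; the proto accessors are the standard generated API for the message above.

  // Sketch: inspect per-core argument shapes in an AOT compilation request.
  tensorflow::tpu::TpuAotCompilationRequestProto request;
  if (request.ParseFromString(request_bytes)) {  // request_bytes: hypothetical
    for (int core = 0; core < request.per_core_arg_shapes_size(); ++core) {
      for (const xla::ShapeProto& shape :
           request.per_core_arg_shapes(core).shapes()) {
        LOG(INFO) << "core " << core << ": " << xla::Shape(shape).ToString();
      }
    }
  }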
119
tensorflow/core/tpu/kernels/tpu_compile_c_api.h
Normal file
@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_

#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

enum TpuCoreTypeEnum {
  kTensorCore,
  kEmbeddingV1,
  kEmbeddingV2,
};

typedef struct XLA_TpuProgram XLA_TpuProgram;

// Property for creating compilation cache key.
struct CompilationCacheKeyProperty {
  const char* config_prefix;
  const char* shapes_prefix;
  const char* function_name;
  const char* mlir_module;
  const int32_t* device_ids;
  size_t device_ids_size;
  int32_t guaranteed_constants_size;
  uint64_t function_library_fingerprint;
  int32_t num_cores_per_replica;
  int32_t num_replicas;
  const XLA_TpuMeshState* mesh_state;
};

extern "C" {

// Creates a new TPU program.
XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded
// and destroyed, it is in an unusable state.
void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                 SE_Status* status);

// Gets TPU program size in bytes from the `tpu_program`.
int64_t TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs the summary of current memory state snapshot of the `tpu_program`.
bool TpuProgram_LogProgramMemorySummary(const XLA_TpuProgram* tpu_program);

// Gets TPU program executable info from the `tpu_program`.
void TpuProgram_GetExecutableInfo(const XLA_TpuProgram* tpu_program,
                                  TpuSerializedProto* executable_info);

// Gets host transfer info proto.
void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program,
    TpuSerializedProto* host_transfer_info);

// Gets HLO metadata proto.
void TpuProgram_GetHloMetadata(const XLA_TpuProgram* tpu_program,
                               TpuSerializedProto* hlo_metadata);

// Returns the number of available TPU cores.
int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state,
                                   TpuCoreTypeEnum tpu_core_type);

// Creates a unique compilation cache `key` used for `put` and `get`
// operations. The returned buffer is heap-allocated and ownership passes to
// the caller.
const char* TpuCompile_CreateCompilationCacheKey(
    CompilationCacheKeyProperty property);

// Creates a guaranteed-const fingerprint. Guaranteed const is normally used
// in TPU inference to avoid re-copying unchanged variables onto the TPU
// device. It promises the value is identical for every execution in the same
// session even if the actual value changes in later executions.
uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint,
                                                     const char* data,
                                                     size_t size);

// Checks whether TPU compilation is enabled.
bool TpuCompile_IsTpuCompilationEnabled();

// Compiles the computations using the XLA TPU compiler and returns TPU
// programs ready for execution.
void TpuCompile_CompileAheadOfTime(
    TpuSerializedProto aot_compilation_request,
    XLA_TpuProgram** tpu_programs[],
    size_t* count, SE_Status* status);

// Builds a `DeviceAssignment` from a serialized `TpuCompileMetadata` proto.
void TpuCompile_BuildXLADeviceAssignment(
    TpuSerializedProto serialized_tpu_compile_metadata,
    const XLA_TpuMeshState* mesh_state,
    TpuSerializedProto* serialized_device_assignment, SE_Status* status);

// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
void TpuCompile_ToTpuShapeRepresentation(
    TpuSerializedProto serialized_xla_shape, int data_type,
    bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
    SE_Status* status);

}  // extern "C"

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
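The fingerprint helper above is designed to be chained across the guaranteed-const buffers of a program, feeding each call's result into the next. A minimal sketch follows; the initial seed of 0 and the `data0`/`data1` buffers are illustrative assumptions, not something this header specifies.

  // Sketch: fold several guaranteed-const buffers into one fingerprint.
  uint64_t fp = 0;  // assumed initial seed
  fp = TpuCompile_CreateGuaranteedConstFingerprint(fp, data0, size0);
  fp = TpuCompile_CreateGuaranteedConstFingerprint(fp, data1, size1);
  // `fp` now identifies this exact sequence of guaranteed-const values, so a
  // compilation cache key built from it stays stable across executions.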
42
tensorflow/core/tpu/kernels/tpu_compile_op_options.cc
Normal file
@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"

namespace tensorflow {
namespace internal {

namespace {
static bool tpu_compilation_cancellation_terminates_process = true;
static bool tpu_compilation_failure_closes_chips = true;
}  // namespace

void SetTpuCompilationCancellationTerminatesProcess(bool b) {
  tpu_compilation_cancellation_terminates_process = b;
}

bool TpuCompilationCancellationTerminatesProcess() {
  return tpu_compilation_cancellation_terminates_process;
}

void SetTpuCompilationFailureClosesChips(bool value) {
  tpu_compilation_failure_closes_chips = value;
}

bool TpuCompilationFailureClosesChips() {
  return tpu_compilation_failure_closes_chips;
}

}  // namespace internal
}  // namespace tensorflow
42
tensorflow/core/tpu/kernels/tpu_compile_op_options.h
Normal file
@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_

#include <string>

namespace tensorflow {
namespace internal {

// Setter and getter that determine how TPUCompile responds to cancelled
// compilation. By default this is true, meaning cancelled compilation will
// abort the process, since that's the only mechanism we have available.
//
// Setting this to false allows the process to remain alive, and should only
// be used in tests.
void SetTpuCompilationCancellationTerminatesProcess(bool b);
bool TpuCompilationCancellationTerminatesProcess();

// Setter and getter that determine whether a TPU compilation failure will
// cause chips to close. By default this is true, which is suitable for
// training. For inference, we never want servers to die, so chips are kept
// alive. See b/109873767.
void SetTpuCompilationFailureClosesChips(bool value);
bool TpuCompilationFailureClosesChips();

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_
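In practice these flags are flipped at process start-up, before any TPUCompile op runs. A minimal sketch of the test-only configuration the comments above describe; the surrounding test fixture is an assumption, the two calls are the API declared in this header.

  // Sketch: keep a test process alive across cancelled TPU compilations, and
  // keep chips open across compilation failures (inference-style behavior).
  tensorflow::internal::SetTpuCompilationCancellationTerminatesProcess(false);
  tensorflow::internal::SetTpuCompilationFailureClosesChips(false);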
439
tensorflow/core/tpu/kernels/tpu_compile_op_support.cc
Normal file
@ -0,0 +1,439 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {

using stream_executor::port::Status;
using stream_executor::port::StatusOr;
using xla::ComputationLayout;
using xla::DebugOptions;
using xla::DeviceAssignment;
using xla::HloModuleConfig;
using xla::HloSharding;
using xla::InvalidArgument;
using xla::ProgramShape;
using xla::Shape;
using xla::ShapeTree;
using xla::ShapeUtil;

Status ValidateResultShape(const Shape& client_shape,
                           const Shape& result_shape) {
  TF_RETURN_IF_ERROR(
      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        xla::ShapeUtil::HumanStringWithLayout(client_shape),
        xla::ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that shape of arguments matches the shape of the arguments in the
    // ProgramShape.
    if (!ShapeUtil::Compatible(argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            argument_shapes[i]));
  }

  if (result_layout.has_value()) {
    TF_RETURN_IF_ERROR(
        ValidateResultShape(result_layout.value(), program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            result_layout.value()));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  config->set_replica_count(replica_count);
  config->set_num_partitions(num_partitions);
  if (seed != nullptr) {
    config->set_seed(*seed);
  }
  if (launch_id != nullptr) {
    config->set_launch_id(*launch_id);
  }
  if (debug_options != nullptr) {
    config->set_debug_options(*debug_options);
  } else {
    config->set_debug_options(xla::GetDebugOptionsFromFlags());
  }

  // TODO(henrytan): set intra_op_parallelism_threads.
  // Reference:
  // tensorflow/compiler/xla/service/service.cc?l=324.

  if (device_assignment.has_value()) {
    config->set_static_device_assignment(device_assignment.value());
  }

  if (alias_passthrough_params != nullptr) {
    config->set_alias_passthrough_params(*alias_passthrough_params);
  }

  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
    config->set_fusion_config_collection(*fusion_config_collection);
    *config->mutable_fusion_config() = *fusion_config;
  }

  return std::move(config);
}

StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options) {
  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
                            device_assignment, replica_count, num_partitions,
                            debug_options, /*seed=*/nullptr,
                            /*launch_id=*/nullptr,
                            /*alias_passthrough_params=*/nullptr,
                            /*fusion_config_collection=*/nullptr,
                            /*fusion_config=*/nullptr);
}

ShapeTree<HloSharding> GetSubtree(
    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
  ShapeTree<HloSharding> element_shape_tree(
      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
                                           element_index),
      HloSharding::Replicate());

  xla::ShapeIndex src_index;
  src_index.push_back(element_index);
  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
  return element_shape_tree;
}

Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
                        int64 device) {
  if (shape.IsTuple()) {
    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
    std::vector<Shape> arg_shapes;
    for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
      HloSharding element_sharding = tuple_shape_tree.element({i});
      if (element_shape.IsTuple()) {
        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
      }
      if (element_sharding.UsesDevice(device)) {
        arg_shapes.push_back(
            GetPerDeviceShape(element_shape, element_sharding, device));
      }
    }
    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
  }

  if (sharding.IsTileMaximal()) {
    return shape;
  }

  std::vector<int64> dimensions;
  std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
  std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
  for (int64 i = 0; i < limit.size(); ++i) {
    dimensions.push_back(limit[i] - offset[i]);
  }
  if (shape.has_layout()) {
    return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
                                               shape.layout().minor_to_major());
  }
  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}
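// Worked example (illustrative, not part of the original file): for a shape
// f32[8,16] with an OTHER sharding that tiles dimension 0 across two devices,
// TileOffsetForDevice returns {0, 0} for device 0 and {4, 0} for device 1,
// while TileLimitForDevice returns {4, 16} and {8, 16} respectively, so
// GetPerDeviceShape yields f32[4,16] on each device. A tile-maximal sharding
// would instead return the full f32[8,16] unchanged.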
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
  // Add all variables to the corresponding core.
  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
  int resource_update_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      const auto& sharding = proto_arg.sharding();
      bool updated = false;
      if (resource_update_pos < compilation_result.resource_updates.size()) {
        const XlaCompiler::ResourceUpdate& update =
            compilation_result.resource_updates[resource_update_pos];
        if (update.input_index == i) {
          updated = true;
          int pos = compilation_result.outputs.size() + resource_update_pos;
          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
              compilation_result.xla_output_shape, pos);
          auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
            (*per_core_output_shapes)[core].push_back(per_core_shape);
            (*may_modify_variables)[core] =
                (*may_modify_variables)[core] || update.modified;
          };
          if (sharding.type() == xla::OpSharding::MAXIMAL) {
            add_to_core(sharding.tile_assignment_devices(0), shape);
          } else if (sharding.type() == xla::OpSharding::OTHER) {
            auto sharding_or =
                xla::HloSharding::FromProto(proto_arg.sharding());
            TF_RET_CHECK(sharding_or.ok());
            for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
              xla::Shape per_core_shape =
                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
              add_to_core(core, per_core_shape);
            }
          } else {
            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
            for (int64 core = 0; core < metadata.num_cores_per_replica();
                 ++core) {
              add_to_core(core, shape);
            }
          }
          ++resource_update_pos;
        }
      }
      if (sharding.type() == xla::OpSharding::MAXIMAL) {
        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
            .push_back(
                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
      } else if (sharding.type() == xla::OpSharding::OTHER) {
        for (int core : sharding.tile_assignment_devices()) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      } else {
        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
        for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      }
    }
  }
  return Status::OK();
}

Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
        << "TPU compilation output " << i
        << " has a compile-time constant value. "
           "This should never happen.";

    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
        compilation_result.xla_output_shape, i);
    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
    };
    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
                        std::move(shape));
    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
      TF_RET_CHECK(sharding_or.ok());
      for (int64 core : retval.sharding().tile_assignment_devices()) {
        xla::Shape per_core_shape =
            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
        add_shape_to_core(core, std::move(per_core_shape));
      }
    } else {
      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
          << "Not all of the constant tensors were consumed.";
      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
        add_shape_to_core(core, shape);
      }
    }
  }
  return Status::OK();
}

Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
  TF_RET_CHECK(
      compilation_result.computation->proto().has_host_program_shape());

  auto debug_options = xla::DebugOptions();
  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModuleConfig> module_config,
      CreateModuleConfig(
          xla::ProgramShape(
              compilation_result.computation->proto().host_program_shape()),
          compilation_result.xla_input_shapes,
          compilation_result.xla_output_shape, device_assignment,
          metadata.num_replicas(), metadata.num_cores_per_replica(),
          &debug_options));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
                                      *module_config));
  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
  hlo_modules->push_back(std::move(hlo_module));

  return Status::OK();
}

XlaCompilationResultProto SerializeCompilationResult(
    const XlaCompiler::CompilationResult& compilation_result) {
  XlaCompilationResultProto compilation_result_proto;
  for (int input_mapping : compilation_result.input_mapping) {
    compilation_result_proto.add_input_mappings(input_mapping);
  }

  for (const Shape& input_shape : compilation_result.xla_input_shapes) {
    *(compilation_result_proto.add_xla_input_shapes()) = input_shape.ToProto();
  }
  *(compilation_result_proto.mutable_xla_output_shape()) =
      compilation_result.xla_output_shape.ToProto();

  for (const XlaCompiler::OutputDescription& output_description :
       compilation_result.outputs) {
    auto* new_output = compilation_result_proto.add_outputs();
    new_output->set_type(output_description.type);
    output_description.shape.AsProto(new_output->mutable_shape());
    new_output->set_is_constant(output_description.is_constant);
    output_description.constant_value.AsProtoField(
        new_output->mutable_constant_value());
    new_output->set_input_index(output_description.input_index);
    new_output->set_is_tensor_list(output_description.is_tensor_list);
  }

  *compilation_result_proto.mutable_host_compute_metadata() =
      compilation_result.host_compute_metadata;

  for (const XlaCompiler::ResourceUpdate& resource_update :
       compilation_result.resource_updates) {
    auto* new_resource_update = compilation_result_proto.add_resource_updates();
    new_resource_update->set_input_index(resource_update.input_index);
    new_resource_update->set_type(resource_update.type);
    resource_update.shape.AsProto(new_resource_update->mutable_shape());
    new_resource_update->set_modified(resource_update.modified);
    for (const std::string& gradient_access :
         resource_update.tensor_array_gradients_accessed) {
      new_resource_update->mutable_tensor_array_gradients_accessed()->insert(
          {gradient_access, true});
    }
  }

  if (compilation_result.computation != nullptr) {
    *compilation_result_proto.mutable_computation() =
        compilation_result.computation->proto();
  }

  return compilation_result_proto;
}

StatusOr<TpuAotCompilationRequestProto> CreateTpuAotCompilationRequest(
    const xla::HloModuleGroup& module_group,
    const XlaCompiler::CompilationResult& compilation_result,
    const TPUCompileMetadataProto& metadata,
    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
    const std::vector<std::vector<std::pair<int, bool>>>&
        per_core_variable_indices,
    const absl::optional<xla::DeviceAssignment>& device_assignment) {
  VLOG(1) << "CreateTpuAotCompilationRequest.";
  TpuAotCompilationRequestProto aot_request;
  *(aot_request.mutable_hlo_module_group()) = module_group.ToProto();
  *(aot_request.mutable_metadata()) = metadata;
  if (device_assignment.has_value()) {
    xla::DeviceAssignmentProto device_assignment_proto;
    Status status = device_assignment->Serialize(&device_assignment_proto);
    if (!status.ok()) {
      return status;
    }
    *(aot_request.mutable_device_assignment()) = device_assignment_proto;
  }

  for (const auto& arg_shapes : per_core_arg_shapes) {
    auto* new_shape_list = aot_request.add_per_core_arg_shapes();
    for (const auto& arg_shape : arg_shapes) {
      *new_shape_list->add_shapes() = arg_shape.ToProto();
    }
  }

  for (const auto& output_shapes : per_core_output_shapes) {
    auto* new_shape_list = aot_request.add_per_core_output_shapes();
    for (const auto& output_shape : output_shapes) {
      *new_shape_list->add_shapes() = output_shape.ToProto();
    }
  }

  for (const auto& variable_indices : per_core_variable_indices) {
    auto* new_list = aot_request.add_per_core_variable_indices();
    for (const auto& variable_index : variable_indices) {
      auto* core_index = new_list->add_variable_indices();
      core_index->set_index(variable_index.first);
      core_index->set_updated(variable_index.second);
    }
  }

  XlaCompilationResultProto compilation_result_proto =
      SerializeCompilationResult(compilation_result);
  *aot_request.mutable_compilation_result() = compilation_result_proto;

  VLOG(1) << "TpuAotCompilationRequest:\n" << aot_request.DebugString();
  return aot_request;
}
}  // namespace tpu
}  // namespace tensorflow
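To tie the pieces together, here is a minimal, hypothetical use of the simplified CreateModuleConfig overload above; `program_shape` and `arg_shapes` stand in for values normally taken from an XlaCompiler::CompilationResult, and the replica/partition counts are illustrative.

  // Sketch: build an HloModuleConfig for a 2-replica, single-partition run.
  auto config_or = tensorflow::tpu::CreateModuleConfig(
      program_shape, arg_shapes, /*result_layout=*/absl::nullopt,
      /*device_assignment=*/absl::nullopt, /*replica_count=*/2,
      /*num_partitions=*/1, /*debug_options=*/nullptr);
  if (config_or.ok()) {
    std::unique_ptr<xla::HloModuleConfig> config =
        std::move(config_or.ValueOrDie());
    // `config` now carries default debug options from flags, since
    // debug_options was passed as nullptr.
  }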
122
tensorflow/core/tpu/kernels/tpu_compile_op_support.h
Normal file
@ -0,0 +1,122 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_

#include <string>
#include <vector>

#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"

namespace tensorflow {
namespace tpu {

namespace se = ::stream_executor;

// Describes the position of an argument or return value after the computation
// has been partitioned into cores.
struct ShardingAndIndex {
  // Sharding across cores.
  ::xla::OpSharding sharding;
  // Argument/return value number. If sharding is single-core, `indices` has a
  // single element; otherwise, it has num_cores elements.
  std::vector<int> indices;
};

// TODO(b/158279168): Dedup with internal version.
// Return the per-device shape for a `shape` with a given `sharding`.
xla::Shape GetPerDeviceShape(const xla::Shape& shape,
                             const xla::HloSharding& sharding,
                             int64 device);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count,
    int num_partitions, const xla::DebugOptions* debug_options);

xla::ShapeTree<xla::HloSharding> GetSubtree(
    const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
    int element_index);

xla::Shape GetPerDeviceShape(const xla::Shape& shape,
                             const xla::HloSharding& sharding,
                             int64 device);

Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);

se::port::Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes);

se::port::Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);

se::port::StatusOr<TpuAotCompilationRequestProto>
CreateTpuAotCompilationRequest(
    const xla::HloModuleGroup& module_group,
    const XlaCompiler::CompilationResult& compilation_result,
    const TPUCompileMetadataProto& metadata,
    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
    const std::vector<std::vector<std::pair<int, bool>>>&
        per_core_variable_indices,
    const absl::optional<xla::DeviceAssignment>& device_assignment);
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
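One caller-side detail worth noting: AddVariableUpdatesToCores indexes the per-core output containers by core number, so the caller must size them to num_cores_per_replica before the call (may_modify_variables is resized internally). A hedged sketch, where `metadata`, `compilation_result`, and `arg_core_mapping` are assumed inputs from earlier compilation steps:

  // Sketch: pre-size the per-core containers, then collect variable updates.
  const int num_cores = metadata.num_cores_per_replica();
  std::vector<bool> may_modify_variables;
  std::vector<std::vector<xla::Shape>> per_core_output_shapes(num_cores);
  std::vector<std::vector<std::pair<int, bool>>> per_core_variable_indices(
      num_cores);
  TF_RETURN_IF_ERROR(tensorflow::tpu::AddVariableUpdatesToCores(
      metadata, compilation_result, arg_core_mapping, &may_modify_variables,
      &per_core_output_shapes, &per_core_variable_indices));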
298
tensorflow/core/tpu/kernels/tpu_configuration_ops.cc
Normal file
@ -0,0 +1,298 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/tpu_config_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {

Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshCommonStateResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return Status::OK();
}

// Attempts to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded, or if the resource was not found.
// Otherwise, returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return Status::OK();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return Status::OK();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}

}  // namespace

void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(input_tensor.shape()),
        errors::InvalidArgument("Input ", i, " should be a scalar but has ",
                                input_tensor.dims(), " dimensions"));
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      OP_REQUIRES(
          ctx, chips_per_host == input_tensor.scalar<int32>()(),
          errors::Internal("Host ", i, " has ", input_tensor.scalar<int32>()(),
                           " TPU chips but host 0 has ", chips_per_host));
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }

  TF_Status* status = TF_NewStatus();
  size_t host_config_output_size;
  char* host_config_output;

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshCommonStateResourceName));

  ConfigureDistributedTpuOp_DoWork(
      num_devices_per_host.size(), num_devices_per_host.data(),
      &host_config_output_size, &host_config_output, status);

  OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                   tpu::kTpuMeshCommonStateResourceName,
                                   tpu::TpuMeshStateInterface::Create()));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(host_config_output, host_config_output_size);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);
  TpuConfigurationApi_FreeCharArray(host_config_output);

  VLOG(1) << "ConfigureDistributedTpuOp done";
}

void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  TF_Status* status = TF_NewStatus();
  size_t tpu_topology_output_size;
  char* tpu_topology_output;

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  WaitForDistributedTpuOp_DoWork(
      num_hosts, num_devices_per_host,
      const_cast<const int32_t**>(mapping_arg.data()), mesh_state,
      &tpu_topology_output_size, &tpu_topology_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);
  TpuConfigurationApi_FreeCharArray(tpu_topology_output);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  TF_Status* status = TF_NewStatus();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          GetTPUConfigResourceMgr(),
                          tpu::kTpuMeshCommonStateResourceName));
  ShutdownDistributedTpuOp_DoWork(status);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  size_t device_id_output_size;
  int32_t* device_id_output;
  TF_Status* status = TF_NewStatus();

  InitializeHostForDistributedTpuOp_DoWork(
      tpu_host_config.size(), tpu_host_config.data(),
      enable_whole_mesh_compilations_, &device_id_output_size,
      &device_id_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32>()(i) = device_id_output[i];
  }

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);
  TpuConfigurationApi_FreeInt32Array(device_id_output);

  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();

  SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  int32_t number_of_chips_output = 0;

  DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}

// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
|
||||
InitializeHostForDistributedTpuOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
|
||||
SetGlobalTPUArrayOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
|
||||
.Device(DEVICE_TPU_SYSTEM)
|
||||
.HostMemory("number_of_tpu_chips"),
|
||||
DisconnectDistributedTpuChipsOp);
|
||||
|
||||
} // namespace tensorflow
|
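Every kernel above bridges the C API's TF_Status into a tensorflow::Status the same way. A minimal standalone sketch of that pattern, for reference; the commented-out DoWork call is a hypothetical stand-in for any of the *_DoWork C entry points used above, not part of this change:

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"  // StatusFromTF_Status

tensorflow::Status CallTpuCApi() {
  TF_Status* status = TF_NewStatus();
  // SomeTpuApi_DoWork(..., status);  // hypothetical stand-in.
  tensorflow::Status s = tensorflow::StatusFromTF_Status(status);
  TF_DeleteStatus(status);  // Always freed, whether or not `s` is OK.
  return s;
}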
156
tensorflow/core/tpu/kernels/tpu_configuration_ops.h
Normal file
@ -0,0 +1,156 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// The ConfigureDistributedTpu op is used to start a TPUDriver from
// TensorFlow. It should be run on a TPU_SYSTEM device and returns the
// connection host:port for the CompilationCacheServer. The
// CompilationCacheServer will remain live until the device's Resource Manager
// is cleared or a ShutdownDistributedTpuOp is run on the same device.
class ConfigureDistributedTpuOp : public OpKernel {
 public:
  explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES(
        ctx, ctx->num_inputs() > 0,
        errors::Internal("_ConfigureDistributedTPU needs at least one input"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~ConfigureDistributedTpuOp() override {}

 private:
  // ConfigureDistributedTpuOp is neither copyable nor movable.
  ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete;
  ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) =
      delete;
};

// The WaitForDistributedTpuOp op is used to block execution until
// the distributed Tpu system has started up. It must be run on
// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run
// on, after all of the InitializeHostForDistributedTpuOp Ops have
// completed.
class WaitForDistributedTpuOp : public OpKernel {
 public:
  explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_));
    OP_REQUIRES(ctx, startup_timeout_sec_ > 0,
                errors::InvalidArgument("startup_timeout_sec ",
                                        startup_timeout_sec_, " must be >0"));
  }
  void Compute(OpKernelContext* ctx) override;
  ~WaitForDistributedTpuOp() override {}

 private:
  // The time to wait for all hosts to start up.
  int startup_timeout_sec_;

  // WaitForDistributedTpuOp is neither copyable nor movable.
  WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete;
  WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete;
};

// The ShutdownDistributedTpu op is used to stop a running TPUDriver from
// TensorFlow. It should be run on the TPU_SYSTEM device where
// ConfigureDistributedTpuOp was run.
class ShutdownDistributedTpuOp : public OpKernel {
 public:
  explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~ShutdownDistributedTpuOp() override {}

 private:
  // ShutdownDistributedTpuOp is neither copyable nor movable.
  ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete;
  ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete;
};

// The InitializeHostForDistributedTpu op is used to initialize the
// TPUPlatform on a host in a distributed TPU system. It should be
// run on every host containing TPU devices before any other Ops that use
// TPU are run.
class InitializeHostForDistributedTpuOp : public OpKernel {
 public:
  explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    ctx->GetAttr("enable_whole_mesh_compilations",
                 &enable_whole_mesh_compilations_)
        .IgnoreError();
  }

  void Compute(OpKernelContext* ctx) override;

  ~InitializeHostForDistributedTpuOp() override {}

 private:
  // InitializeHostForDistributedTpuOp is neither copyable nor movable.
  InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) =
      delete;
  InitializeHostForDistributedTpuOp& operator=(
      const InitializeHostForDistributedTpuOp&) = delete;

  bool enable_whole_mesh_compilations_ = false;
};

// The SetGlobalTPUArray op is used to initialize the TPUPlatform on a
// host in a distributed TPU system. It should be run on every host
// containing TPU devices before any other Ops that use TPU are run.
class SetGlobalTPUArrayOp : public OpKernel {
 public:
  explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~SetGlobalTPUArrayOp() override {}

 private:
  // SetGlobalTPUArrayOp is neither copyable nor movable.
  SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete;
  SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete;
};

// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a
// host from a running TPUDriver instance. It should be run on every host
// containing TPU devices before the ShutdownDistributedTpuOp is run on
// the TPU_SYSTEM.
class DisconnectDistributedTpuChipsOp : public OpKernel {
 public:
  explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override;

  ~DisconnectDistributedTpuChipsOp() override {}

 private:
  // DisconnectDistributedTpuChipsOp is neither copyable nor movable.
  DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) =
      delete;
  DisconnectDistributedTpuChipsOp& operator=(
      const DisconnectDistributedTpuChipsOp&) = delete;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_
94
tensorflow/core/tpu/kernels/tpu_executable_info.proto
Normal file
@ -0,0 +1,94 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow;

import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor_shape.proto";

// A serialization of TPUExecutable. Only includes fields necessary to load
// and execute a program on a worker node.
message TPUExecutableInfoProto {
  reserved 1;

  // The shapes of the inputs and outputs.
  repeated xla.ShapeProto input_shapes = 2;
  reserved 7;  // was input_shape
  xla.ShapeProto output_shape = 3;

  message UpdateIndexPair {
    int32 index = 1;
    bool updated = 2;
  }

  message ShapeIndex {
    repeated int32 index = 1;
  }

  // Dynamic output indices indicate which outputs have dynamic dimensions.
  repeated ShapeIndex dynamic_output_indices = 11;

  // For each resource variable output, what was the index of the corresponding
  // input and was it updated? The indices are sorted by input order.
  repeated UpdateIndexPair variable_indices = 10;

  // The shapes of the outputs when represented as Tensors. These may not
  // match the output_shape values because we may flatten tensors to avoid
  // excess padding.
  repeated TensorShapeProto output_tensor_shapes = 8;

  reserved 4;

  // Optional session module for passing XLA computations between TPUCompileOp
  // and TPUExecuteOp. This is needed to support the
  // --xla_dump_hlo_snapshots flag.
  xla.HloSnapshot session_module = 5;

  // The physical device ids assigned to the replicated cores.
  xla.DeviceAssignmentProto device_assignment = 6;
}

// Metadata for a data transfer between device and host.
message TPUHostTransferProto {
  enum TransferDirection {
    NONE = 0;
    DEVICE_TO_HOST = 1;
    HOST_TO_DEVICE = 2;
  }
  // Channel identifier assigned by the compiler and used in host commands.
  int64 channel = 1;
  // Direction of the transfer operation.
  TransferDirection direction = 2;
  // Channel identifier provided by the XLA client.
  string key = 3;
  // Depth of nested loops for this transfer operation.
  int64 nested_while_level = 4;
  // Shape of the data to be transferred (including layout).
  xla.ShapeProto shape = 5;
  // Address of the device buffer in HBM (byte offset).
  int64 buffer_offset = 6;
  // Original data type for this host transfer before X64 rewrite.
  xla.PrimitiveType original_type = 7;
  // If this host transfer is a split X64 transfer, specifies whether this
  // transfer is for the lower bits.
  bool is_lower_bits = 8;
}

message TPUHostTransferInfoProto {
  repeated TPUHostTransferProto host_transfers = 1;
}
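As an illustration of the variable_indices bookkeeping described above, a hedged C++ sketch that records one resource-variable output as aliasing input 3 and updated in place; the field values are made up for the example, and the generated-header path is assumed from the proto's location:

#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"

tensorflow::TPUExecutableInfoProto MakeExampleInfo() {
  tensorflow::TPUExecutableInfoProto info;
  // This resource-variable output came from input 3 and was updated in place.
  auto* pair = info.add_variable_indices();
  pair->set_index(3);
  pair->set_updated(true);
  return info;
}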
30
tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
Normal file
@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_

typedef struct XLA_TpuMeshState XLA_TpuMeshState;

// Creates a new TPU mesh state object.
XLA_TpuMeshState* TpuMeshState_Create();

// Deletes the given TPU `mesh_state` object. Once deleted the object is
// unusable.
void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

// Returns a pointer to an opaque mesh data structure used internally.
void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state);

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
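A minimal usage sketch for the C API above, following the ownership rules in its comments: the common-state pointer is borrowed, and the mesh state must not be used after TpuMeshState_Free.

#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"

void MeshStateRoundTrip() {
  XLA_TpuMeshState* mesh_state = TpuMeshState_Create();
  // Borrow the opaque common state; ownership stays with `mesh_state`.
  void* common_state = TpuMeshState_MeshCommonState(mesh_state);
  (void)common_state;
  TpuMeshState_Free(mesh_state);  // `mesh_state` is unusable after this.
}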
78
tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h
Normal file
@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_

#include <string>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"

namespace tensorflow {

class TpuMeshCommonState;

namespace tpu {

const char kTpuMeshCommonStateResourceName[] = "tpu_mesh_common_state";

class TpuMeshStateInterface : public tensorflow::ResourceBase {
 public:
  explicit TpuMeshStateInterface(XLA_TpuMeshState* handle)
      : mesh_state_(handle) {}

  ~TpuMeshStateInterface() override {
    if (mesh_state_ != nullptr) {
      TpuMeshState_Free(mesh_state_);
    }
  }

  static TpuMeshStateInterface* Create() {
    return new TpuMeshStateInterface(TpuMeshState_Create());
  }

  const XLA_TpuMeshState* data() const { return mesh_state_; }

  tensorflow::TpuMeshCommonState* mesh_common_state() const {
    return static_cast<tensorflow::TpuMeshCommonState*>(
        TpuMeshState_MeshCommonState(mesh_state_));
  }

  // Returns whether we should include the device assignment as a static field
  // to the TPU program. This also determines whether we should include the
  // device assignment as part of the compilation cache key.
  bool NeedsStaticDeviceAssignment(const TPUCompileMetadataProto& metadata,
                                   TpuCoreTypeEnum tpu_core_type) const {
    // Static device assignment enables XLA to perform certain optimizations
    // when all cores are used in the replicated computation.
    return metadata.num_cores_per_replica() * metadata.num_replicas() ==
           TpuTopology_AvailableCoreCount(mesh_state_, tpu_core_type);
  }

  string DebugString() const override { return "TpuMeshStateInterface"; }

 private:
  XLA_TpuMeshState* mesh_state_;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_
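A hedged sketch of how the interface above is registered and retrieved via a ResourceMgr, mirroring ConfigureDistributedTpuOp and WaitForDistributedTpuOp in tpu_configuration_ops.cc; the free function name and the source of `rmgr` (GetTPUConfigResourceMgr in the kernels) are the only assumptions:

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"

tensorflow::Status RegisterAndLookupMeshState(tensorflow::ResourceMgr* rmgr) {
  namespace tpu = tensorflow::tpu;
  // Ownership of the new mesh state passes to the resource manager.
  TF_RETURN_IF_ERROR(rmgr->Create(rmgr->default_container(),
                                  tpu::kTpuMeshCommonStateResourceName,
                                  tpu::TpuMeshStateInterface::Create()));
  tpu::TpuMeshStateInterface* mesh_state = nullptr;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kTpuMeshCommonStateResourceName,
                                  &mesh_state));
  tensorflow::core::ScopedUnref unref(mesh_state);  // Lookup adds a reference.
  return tensorflow::Status::OK();
}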
201
tensorflow/core/tpu/kernels/tpu_program.cc
Normal file
@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_program.h"

#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

namespace tensorflow {
namespace tpu {

namespace {

namespace se_tpu = ::stream_executor::tpu;

using stream_executor::port::StatusOr;
using xla::Shape;

StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
    std::unique_ptr<xla::HloModuleGroup> module_group,
    const XlaCompiler::CompilationResult& compilation_result,
    const TPUCompileMetadataProto& metadata,
    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
    const std::vector<std::vector<std::pair<int, bool>>>&
        per_core_variable_indices,
    const absl::optional<xla::DeviceAssignment>& device_assignment) {
  VLOG(1) << "Run CompileAheadOfTime.";
  TF_ASSIGN_OR_RETURN(TpuAotCompilationRequestProto aot_request,
                      CreateTpuAotCompilationRequest(
                          *module_group, compilation_result, metadata,
                          per_core_arg_shapes, per_core_output_shapes,
                          per_core_variable_indices, device_assignment));
  se_tpu::SerializedProto serialized_aot_request =
      se_tpu::SerializeProto(aot_request);
  auto cleanup = gtl::MakeCleanup([serialized_aot_request] {
    se_tpu::SerializedProto_Free(serialized_aot_request);
  });

  XLA_TpuProgram** xla_tpu_programs = nullptr;
  size_t count = 0;
  StatusHelper status;
  VLOG(1) << "Run TpuCompile_CompileAheadOfTime.";
  TpuCompile_CompileAheadOfTime(serialized_aot_request, &xla_tpu_programs,
                                &count, status.c_status);
  VLOG(1) << "Run CompileAheadOfTime completed.";
  if (!status.status().ok()) {
    return status.status();
  }
  std::vector<XLA_TpuProgram*> tpu_programs(count, nullptr);
  for (size_t i = 0; i < count; ++i) {
    tpu_programs[i] = xla_tpu_programs[i];
  }
  delete[] xla_tpu_programs;
  return tpu_programs;
}

StatusOr<std::vector<XLA_TpuProgram*>> CompileAheadOfTime(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
    const std::vector<std::vector<xla::Shape>>& per_core_output_shapes,
    const std::vector<std::vector<std::pair<int, bool>>>&
        per_core_variable_indices,
    const absl::optional<xla::DeviceAssignment>& device_assignment) {
  VLOG(1) << "Compile Tpu programs.";
  std::vector<std::unique_ptr<xla::HloModule>> hlo_modules;
  auto status = CreateHloModules(metadata, compilation_result,
                                 device_assignment, &hlo_modules);
  if (!status.ok()) {
    return status;
  }

  return CompileAheadOfTime(
      absl::make_unique<xla::HloModuleGroup>(hlo_modules[0]->name(),
                                             absl::MakeSpan(hlo_modules)),
      compilation_result, metadata, per_core_arg_shapes,
      per_core_output_shapes, per_core_variable_indices, device_assignment);
}

}  // namespace

int64_t TpuProgram::program_size() const {
  int64_t total_size = 0;
  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
    total_size += TpuProgram_GetProgramSize(tpu_program);
  }
  return total_size;
}

bool TpuProgram::LogProgramMemorySummary() {
  bool success = true;
  for (const XLA_TpuProgram* tpu_program : tpu_programs_) {
    success &= TpuProgram_LogProgramMemorySummary(tpu_program);
  }
  return success;
}

void TpuProgram::UnloadAndDestroyPrograms() {
  for (XLA_TpuProgram* tpu_program : tpu_programs_) {
    StatusHelper status;
    TpuProgram_UnloadAndDestroy(tpu_program, status.c_status);
    auto s = status.status();
    if (!s.ok()) {
      LOG(ERROR) << "TpuProgram::UnloadPrograms(): " << s.ToString();
    }
  }
  tpu_programs_.clear();
}

/*static*/ Status TpuProgram::Build(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
    const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
    TpuProgram* tpu_program) {
  std::vector<std::vector<xla::Shape>> per_core_output_shapes(
      metadata.num_cores_per_replica());
  TF_RETURN_IF_ERROR(ComputeOutputShapesForEachCore(
      metadata, compilation_result, &per_core_output_shapes));

  std::vector<std::vector<std::pair<int, bool>>> per_core_variable_indices(
      metadata.num_cores_per_replica());
  std::vector<bool> may_modify_variables;
  TF_RETURN_IF_ERROR(AddVariableUpdatesToCores(
      metadata, compilation_result, arg_core_mapping, &may_modify_variables,
      &per_core_output_shapes, &per_core_variable_indices));
  TF_RET_CHECK(per_core_arg_shapes.size() == metadata.num_cores_per_replica());
  TF_RET_CHECK(per_core_output_shapes.size() == per_core_arg_shapes.size());
  TF_RET_CHECK(per_core_output_shapes.size() ==
               per_core_variable_indices.size());
  tpu_program->set_may_modify_variables(may_modify_variables);

  // With shardable input/output pairs, XLA could generate separate
  // sharding/unsharding programs along with the main program. The
  // sharding/unsharding programs will be in nested entries of the AOT
  // compilation result.
  auto status_or = CompileAheadOfTime(
      metadata, compilation_result, per_core_arg_shapes,
      per_core_output_shapes, per_core_variable_indices,
      xla_device_assignment);

  TF_ASSIGN_OR_RETURN(std::vector<XLA_TpuProgram*> xla_tpu_programs,
                      std::move(status_or));
  // SPMD could return 1 result for all partitions.
  TF_RET_CHECK(xla_tpu_programs.size() == 1 ||
               xla_tpu_programs.size() == metadata.num_cores_per_replica());
  tpu_program->set_tpu_programs(xla_tpu_programs);

  // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1.
  TpuSerializedProto serialized_executable_info;
  TpuProgram_GetExecutableInfo(xla_tpu_programs[0],
                               &serialized_executable_info);
  TPUExecutableInfoProto executable_info =
      se_tpu::DeserializeProto<TPUExecutableInfoProto>(
          serialized_executable_info);
  tpu_program->set_executable_info(executable_info);
  StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);

  TPUHostTransferInfoProto host_transfer_info;
  TpuSerializedProto serialized_host_transfer_info;
  TpuProgram_GetHostTransferInfo(xla_tpu_programs[0],
                                 &serialized_host_transfer_info);
  if (serialized_host_transfer_info.size > 0) {
    host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
        serialized_host_transfer_info);
    StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
  }
  tpu_program->set_host_transfer_info(host_transfer_info);

  TpuSerializedProto serialized_hlo_metadata;
  TpuProgram_GetHloMetadata(xla_tpu_programs[0], &serialized_hlo_metadata);
  xla::HloProto hlo_metadata =
      se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
  tpu_program->set_hlo_metadata(hlo_metadata);
  StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);

  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow
161
tensorflow/core/tpu/kernels/tpu_program.h
Normal file
@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_

#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {

class TpuAotCompilationOptions : public xla::AotCompilationOptions {
 public:
  explicit TpuAotCompilationOptions(int64 replica_count)
      : num_cores_(0), replica_count_(replica_count) {}

  // Returns the ID of the platform to which these options apply.
  se::Platform::Id PlatformId() const override {
    LOG(FATAL) << "Not implemented.";
    return nullptr;
  };

  void set_num_cores(int64 tpu_cores) { num_cores_ = tpu_cores; }
  int64 replica_count() const override { return replica_count_; }
  int64 num_cores() const override { return num_cores_; }

  void set_allow_separate_sharding_programs(bool allow) {
    allow_separate_sharding_programs_ = allow;
  }
  bool allow_separate_sharding_programs() const {
    return allow_separate_sharding_programs_;
  }

  const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
  shardable_value_update_pairs() const {
    return shardable_value_update_pairs_;
  }
  void set_shardable_value_update_pairs(
      std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
    shardable_value_update_pairs_ = std::move(pairs);
  }

 private:
  int64 num_cores_;
  int64 replica_count_;

  // Whether to allow the compiler to create separate sharding and unsharding
  // programs, and modify the original program's input/output sharded size.
  // This is used for XLA-chosen sharding on parameters without an on-device
  // loop: the caller can invoke sharding first, then (repeatedly) invoke the
  // sharded main program, and finally invoke the unsharding program when it
  // needs the full output.
  bool allow_separate_sharding_programs_ = false;

  // The list of input/output pairs in the main program that could be sharded.
  std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
      shardable_value_update_pairs_;
};

// An executable capable of being fed to a TPU device.
class TpuProgram {
 public:
  using Status = ::stream_executor::port::Status;

  virtual ~TpuProgram() = default;

  static Status Build(
      const TPUCompileMetadataProto& metadata,
      const tensorflow::XlaCompiler::CompilationResult& compilation_result,
      const std::vector<ShardingAndIndex>& arg_core_mapping,
      const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
      const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
      TpuProgram* tpu_program);

  size_t program_count() const { return tpu_programs_.size(); }

  int64_t program_size() const;

  bool LogProgramMemorySummary();

  void UnloadAndDestroyPrograms();

  const std::vector<bool>& may_modify_variables() const {
    return may_modify_variables_;
  }
  void set_may_modify_variables(const std::vector<bool>& may_modify_variables) {
    may_modify_variables_ = may_modify_variables;
  }

  const tf2xla::HostComputeMetadata& host_compute_metadata() const {
    return host_compute_metadata_;
  }
  void set_host_compute_metadata(
      const tf2xla::HostComputeMetadata& host_compute_metadata) {
    host_compute_metadata_ = host_compute_metadata;
  }

  const std::vector<XLA_TpuProgram*>& tpu_programs() const {
    return tpu_programs_;
  }
  void set_tpu_programs(std::vector<XLA_TpuProgram*> tpu_programs) {
    tpu_programs_ = tpu_programs;
  }

  const TPUExecutableInfoProto& executable_info() const {
    return executable_info_;
  }
  void set_executable_info(const TPUExecutableInfoProto& executable_info) {
    executable_info_ = executable_info;
  }

  const TPUHostTransferInfoProto& host_transfer_info() const {
    return host_transfer_info_;
  }
  void set_host_transfer_info(
      const TPUHostTransferInfoProto& host_transfer_info) {
    host_transfer_info_ = host_transfer_info;
  }

  const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
  void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
    hlo_metadata_ = hlo_metadata;
  }

 private:
  std::vector<bool> may_modify_variables_;
  tf2xla::HostComputeMetadata host_compute_metadata_;

  std::vector<XLA_TpuProgram*> tpu_programs_;  // Not owned.
  TPUExecutableInfoProto executable_info_;
  TPUHostTransferInfoProto host_transfer_info_;
  xla::HloProto hlo_metadata_;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
100
tensorflow/core/tpu/kernels/tpu_util.cc
Normal file
@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_util.h"

#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/random.h"

namespace tensorflow {
namespace tpu {

std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
  return session_metadata ? session_metadata->name() : "";
}

std::string ProtoKeyForComputation(const std::string& key, int core) {
  return absl::StrCat(key, ":", core);
}

xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
    const std::string& key) {
  const std::vector<std::string> splits = absl::StrSplit(key, '|');
  if (splits.size() == 1) {
    // No guaranteed_const.
    return TpuCompilationCacheKey(key);
  } else if (splits.size() != 3) {
    return errors::InvalidArgument("Invalid TPU compilation cache key:", key);
  }

  TpuCompilationCacheKey parsed_key(splits.at(0));
  parsed_key.has_guaranteed_const = true;
  parsed_key.session_handle = splits.at(1);
  const string fingerprint = splits.at(2);
  parsed_key.guaranteed_const_fingerprint = [fingerprint] {
    return fingerprint;
  };
  return parsed_key;
}

xla::CompileOnlyClient::AotXlaComputationInstance
BuildAotXlaComputationInstance(
    const XlaCompiler::CompilationResult& compilation_result) {
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = compilation_result.computation.get();
  for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
    instance.argument_layouts.push_back(&shape);
  }
  instance.result_layout = &compilation_result.xla_output_shape;
  return instance;
}

Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
  if (tensor.dtype() != DT_INT64 ||
      !TensorShapeUtils::IsVector(tensor.shape())) {
    return errors::InvalidArgument("Shape tensor must be an int64 vector.");
  }
  const int64 rank = tensor.NumElements();
  auto tensor_dims = tensor.flat<int64>();
  std::vector<int64> dims(rank);
  for (int64 i = 0; i < rank; ++i) {
    dims[i] = tensor_dims(i);
  }
  return TensorShapeUtils::MakeShape(dims, shape);
}

Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
                                   std::vector<TensorShape>* shapes) {
  shapes->resize(dynamic_shapes.size());
  for (int i = 0; i < dynamic_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
  }
  return Status::OK();
}

Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
                                   std::vector<TensorShape>* shapes) {
  shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
  size_t i = 0;
  for (auto& dynamic_shape : dynamic_shapes) {
    TF_RETURN_IF_ERROR(
        ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
    ++i;
  }
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow
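An illustrative use of ParseCompilationCacheKey, following the parsing logic above: a bare key carries no guaranteed_const parts, while a three-part key has the form "<key>|<session_handle>|<guaranteed_const_fingerprint>". The literal key strings here are made up for the example.

#include "tensorflow/core/tpu/kernels/tpu_util.h"

void ParseKeys() {
  // One part: no guaranteed_const metadata.
  auto plain = tensorflow::tpu::ParseCompilationCacheKey("some_key");
  // Three parts: session handle and fingerprint are recovered.
  auto full =
      tensorflow::tpu::ParseCompilationCacheKey("some_key|sess_0|fp_42");
  if (full.ok()) {
    // full.ValueOrDie().session_handle == "sess_0"
    // full.ValueOrDie().guaranteed_const_fingerprint() == "fp_42"
  }
  // Two parts (or more than three) would return InvalidArgument.
  (void)plain;
}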
67
tensorflow/core/tpu/kernels/tpu_util.h
Normal file
@ -0,0 +1,67 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"

namespace tensorflow {
namespace tpu {

// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may
// be null.
std::string SessionNameFromMetadata(const SessionMetadata* session_metadata);

// Generates a cache proto key for a given computation on a TPU core.
std::string ProtoKeyForComputation(const std::string& key, int core);

// Returns a TpuCompilationCacheKey parsed from the given key, or an error.
xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
    const std::string& key);

xla::CompileOnlyClient::AotXlaComputationInstance
BuildAotXlaComputationInstance(
    const XlaCompiler::CompilationResult& compilation_result);

// Returns true if TPU compilation is enabled.
bool IsTpuCompilationEnabled();

// Converts an int64 host memory `tensor` to a `shape`.
Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape);

Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
                                   std::vector<TensorShape>* shapes);
Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
                                   std::vector<TensorShape>* shapes);

// Given a tensor of `shape` and `type`, as what shape should it be stored on
// the TPU device? This function transposes or flattens the excessively-padded
// tensors to rank 1, but leaves other tensor shapes alone.
xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
                                                 DataType type,
                                                 bool use_fast_memory);
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
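A small sketch of ShapeTensorToTensorShape, matching its implementation above: an int64 vector tensor holding [2, 3] converts to TensorShape({2, 3}); any non-int64 or non-vector input yields InvalidArgument.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"

tensorflow::Status ShapeFromTensorExample() {
  tensorflow::Tensor t(tensorflow::DT_INT64, tensorflow::TensorShape({2}));
  t.flat<tensorflow::int64>()(0) = 2;
  t.flat<tensorflow::int64>()(1) = 3;
  tensorflow::TensorShape shape;
  // On success, `shape` is [2, 3].
  return tensorflow::tpu::ShapeTensorToTensorShape(t, &shape);
}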
27
tensorflow/core/tpu/kernels/trace_util.h
Normal file
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_

#ifdef PLATFORM_GOOGLE
#include "base/tracer.h"
#else
#undef TRACESTRING
#define TRACESTRING(x)
#undef TRACELITERAL
#define TRACELITERAL(x)
#endif

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_
@ -64,13 +64,20 @@ TfTpu_ConfigApiFn* ConfigApiFn() {
}

Status InitializeTpuLibrary(void* library_handle) {
  bool shared_object_loaded = true;
  if (library_handle == nullptr) {
    library_handle = dlopen(nullptr, RTLD_LAZY);
    shared_object_loaded = false;
  }

  TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle));
  TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle));

  if (shared_object_loaded) {
    // Initialize TPU platform when the platform code is loaded from a library.
    InitializeApiFn()->TfTpu_InitializeFn();
  }

  return Status::OK();
}
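A hedged sketch of the two initialization paths the hunk above enables: passing a handle from dlopen() of a TPU shared object (which also initializes the platform), or nullptr to resolve the API symbols from the current process (which skips the platform-initialization step). Status, errors, and InitializeTpuLibrary are assumed visible from the enclosing namespace of the patched file.

#include <dlfcn.h>

Status LoadTpuRuntime(const char* shared_object_path) {
  void* handle = nullptr;
  if (shared_object_path != nullptr) {
    handle = dlopen(shared_object_path, RTLD_NOW);
    if (handle == nullptr) {
      return errors::NotFound("Could not dlopen TPU library: ", dlerror());
    }
  }
  // nullptr => symbols come from the current process; the platform is not
  // re-initialized, per the shared_object_loaded branch above.
  return InitializeTpuLibrary(handle);
}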
234
tensorflow/stream_executor/tpu/BUILD
Normal file
@ -0,0 +1,234 @@
# Description: StreamExecutor Interface for TPUs

package(
    default_visibility = ["//tensorflow/core/tpu:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "tpu_executor_c_api_hdrs",
    hdrs = ["tpu_executor_c_api.h"],
    deps = [
        "//tensorflow/c:tf_attrtype",
        "//tensorflow/c:tf_datatype",
        "//tensorflow/c:tf_status",
    ],
)

cc_library(
    name = "tpu_node_context_c_api_hdrs",
    hdrs = ["tpu_node_context_c_api.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
    ],
)

cc_library(
    name = "status_helper",
    hdrs = ["status_helper.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/core/platform:status",
    ],
)

cc_library(
    name = "c_api_conversions",
    hdrs = ["c_api_conversions.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/stream_executor:stream",
        "@com_google_absl//absl/container:inlined_vector",
    ],
)

cc_library(
    name = "proto_helper",
    srcs = ["proto_helper.cc"],
    hdrs = ["proto_helper.h"],
    deps = ["//tensorflow/core:lib"],
)

cc_library(
    name = "tpu_stream",
    hdrs = ["tpu_stream.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor/lib",
    ],
)

cc_library(
    name = "tpu_timer",
    hdrs = ["tpu_timer.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:stream",
    ],
)

cc_library(
    name = "tpu_executor",
    srcs = ["tpu_executor.cc"],
    hdrs = ["tpu_executor.h"],
    deps = [
        ":c_api_conversions",
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        ":tpu_executor_interface",
        ":tpu_platform",
        ":tpu_platform_interface",
        ":tpu_stream",
        ":tpu_timer",
        "//tensorflow/c:tf_status",
        "//tensorflow/core:lib",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

cc_library(
    name = "tpu_executor_hdrs",
    hdrs = ["tpu_executor.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        ":tpu_executor_interface",
        ":tpu_platform_hdrs",
        ":tpu_platform_interface",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:stream_header",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

cc_library(
    name = "tpu_platform_hdrs",
    hdrs = ["tpu_platform.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        ":tpu_platform_interface",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:stream_header",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
)

cc_library(
    name = "tpu_node_context",
    srcs = ["tpu_node_context.cc"],
    hdrs = ["tpu_node_context.h"],
    deps = [
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        ":tpu_node_context_c_api_hdrs",
        ":tpu_platform_interface",
        ":tpu_transfer_manager",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/compiler/xla/service:stream_pool",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:framework",
        "//tensorflow/stream_executor:device_memory_allocator",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/memory",
    ],
)

cc_library(
    name = "tpu_platform",
    srcs = ["tpu_platform.cc"],
    hdrs = ["tpu_platform.h"],
    deps = [
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        ":tpu_executor_hdrs",
        ":tpu_platform_interface",
        "//tensorflow/c:tf_status",
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:stream",
        "@com_google_absl//absl/container:flat_hash_map",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_transfer_manager",
    srcs = ["tpu_transfer_manager_registration.cc"],
    deps = [
        ":tpu_platform",
        ":tpu_transfer_manager_base",
        "//tensorflow/compiler/xla/service:transfer_manager",
    ],
)

cc_library(
    name = "tpu_transfer_manager_base",
    srcs = ["tpu_transfer_manager.cc"],
    hdrs = ["tpu_transfer_manager.h"],
    deps = [
        ":c_api_conversions",
        ":proto_helper",
        ":status_helper",
        ":tpu_executor_c_api_hdrs",
        ":tpu_platform",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/service:shaped_buffer",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/stream_executor:stream",
    ],
)

cc_library(
    name = "tpu_computation_placer",
    srcs = ["tpu_computation_placer.cc"],
    hdrs = ["tpu_computation_placer.h"],
    deps = [
        ":tpu_executor_c_api_hdrs",
        ":tpu_platform",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service:computation_placer",
    ],
    alwayslink = True,
)

cc_library(
    name = "tpu_platform_interface",
    srcs = ["tpu_platform_interface.cc"],
    hdrs = ["tpu_platform_interface.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/platform:types",
        "//tensorflow/stream_executor:multi_platform_manager",
        "//tensorflow/stream_executor:stream_executor_headers",
    ],
)

cc_library(
    name = "tpu_stream_interface",
    hdrs = ["tpu_stream_interface.h"],
    visibility = ["//visibility:public"],
    deps = ["//tensorflow/stream_executor:stream_executor_internal"],
)

cc_library(
    name = "tpu_executor_interface",
    hdrs = ["tpu_executor_interface.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":tpu_platform_interface",
        "//tensorflow/core/platform:errors",
        "//tensorflow/stream_executor:stream_executor_internal",
        "//tensorflow/stream_executor:stream_header",
    ],
)
115
tensorflow/stream_executor/tpu/c_api_conversions.h
Normal file
@ -0,0 +1,115 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
|
||||
class TpuConversions {
|
||||
public:
|
||||
static stream_executor::DeviceMemoryBase
|
||||
SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) {
|
||||
stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size);
|
||||
base.SetPayload(se_base.payload);
|
||||
return base;
|
||||
}
|
||||
|
||||
static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase(
|
||||
const stream_executor::DeviceMemoryBase& base) {
|
||||
SE_DeviceMemoryBase se_base;
|
||||
se_base.opaque = const_cast<void*>(base.opaque());
|
||||
se_base.payload = base.payload();
|
||||
se_base.size = base.size();
|
||||
return se_base;
|
||||
}
|
||||
|
||||
static xla::Shape CShapeToXlaShape(XLA_Shape* shape) {
|
||||
xla::ShapeProto p;
|
||||
p.ParseFromArray(shape->bytes, shape->size);
|
||||
return xla::Shape(p);
|
||||
}
|
||||
|
||||
static void XlaShapeToCShape(const xla::Shape& xla_shape,
|
||||
XLA_Shape* c_shape) {
|
||||
xla::ShapeProto p = xla_shape.ToProto();
|
||||
std::string p_str = p.SerializeAsString();
|
||||
c_shape->bytes = new char[p_str.size()];
|
||||
c_shape->size = p_str.size();
|
||||
memcpy(c_shape->bytes, p_str.data(), p_str.size());
|
||||
}
|
||||
|
||||
static void XLAShapedBufferToCShapedBuffer(
|
||||
const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) {
|
||||
XlaShapeToCShape(buffer.on_host_shape(), &c_device_buffer->on_host_shape);
|
||||
XlaShapeToCShape(buffer.on_device_shape(),
|
||||
&c_device_buffer->on_device_shape);
|
||||
c_device_buffer->device_ordinal = buffer.device_ordinal();
|
||||
absl::InlinedVector<SE_DeviceMemoryBase, 2> bases;
|
||||
for (auto& pair : buffer.buffers()) {
|
||||
bases.push_back(DeviceMemoryBaseToSE_DeviceMemoryBase(pair.second));
|
||||
}
|
||||
c_device_buffer->count = bases.size();
|
||||
c_device_buffer->bases = new SE_DeviceMemoryBase[bases.size()];
|
||||
for (int i = 0; i < bases.size(); ++i) {
|
||||
c_device_buffer->bases[i] = bases[i];
|
||||
}
|
||||
}
|
||||
|
||||
static void XLALiteralToCLiteral(const xla::LiteralSlice& literal,
|
||||
XLA_Literal* c_literal) {
|
||||
XlaShapeToCShape(literal.shape(), &c_literal->shape);
|
||||
auto shapes = xla::ShapeUtil::GetLeafShapes(literal.shape());
|
||||
c_literal->buffers = new char*[shapes.size()];
|
||||
c_literal->sizes = new size_t[shapes.size()];
|
||||
c_literal->count = shapes.size();
|
||||
for (int i = 0; i < shapes.size(); ++i) {
|
||||
c_literal->buffers[i] = reinterpret_cast<char*>(
|
||||
const_cast<void*>(literal.untyped_data(shapes[i].index)));
|
||||
c_literal->sizes[i] = literal.size_bytes(shapes[i].index);
|
||||
}
|
||||
}
|
||||
|
||||
static xla::MutableBorrowingLiteral CLiteralToXLALiteral(
|
||||
XLA_Literal* c_literal) {
|
||||
xla::Shape shape = CShapeToXlaShape(&c_literal->shape);
|
||||
LOG(INFO) << "Shape: " << shape.DebugString();
|
||||
return xla::MutableBorrowingLiteral(
|
||||
absl::MakeSpan(c_literal->buffers, c_literal->count), shape);
|
||||
}
|
||||
|
||||
static void CShapeCleanup(XLA_Shape* c_shape) { delete[] c_shape->bytes; }
|
||||
|
||||
static void CLiteralCleanup(XLA_Literal* c_literal) {
|
||||
delete[] c_literal->buffers;
|
||||
delete[] c_literal->sizes;
|
||||
CShapeCleanup(&c_literal->shape);
|
||||
}
|
||||
|
||||
static void CShapedBufferCleanup(XLA_ShapedBuffer* c_buffer) {
|
||||
CShapeCleanup(&c_buffer->on_device_shape);
|
||||
CShapeCleanup(&c_buffer->on_host_shape);
|
||||
delete[] c_buffer->bases;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
|
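A minimal round-trip sketch for the conversions above (not part of the commit; the function name and the concrete shape are illustrative assumptions): an xla::Shape is serialized into the C representation, parsed back, and the C-side allocation is released with CShapeCleanup.

// Illustrative only: round-trip an xla::Shape through the C ABI types above.
void ShapeRoundTripSketch() {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  XLA_Shape c_shape;
  TpuConversions::XlaShapeToCShape(shape, &c_shape);  // allocates c_shape.bytes
  xla::Shape back = TpuConversions::CShapeToXlaShape(&c_shape);
  CHECK(xla::ShapeUtil::Equal(shape, back));
  TpuConversions::CShapeCleanup(&c_shape);  // frees the serialized bytes
}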
27
tensorflow/stream_executor/tpu/proto_helper.cc
Normal file
27
tensorflow/stream_executor/tpu/proto_helper.cc
Normal file
@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/proto_helper.h"

extern "C" {

void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto) {
  CHECK_NE(proto, nullptr);
  CHECK_NE(proto->bytes, nullptr);
  CHECK_GT(proto->size, 0);
  delete[] proto->bytes;
}

}  // extern "C"
85
tensorflow/stream_executor/tpu/proto_helper.h
Normal file
85
tensorflow/stream_executor/tpu/proto_helper.h
Normal file
@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_

#include <cstddef>

#include "tensorflow/core/platform/logging.h"

extern "C" {

typedef struct TpuSerializedProto {
  const char* bytes;
  size_t size;
} TpuSerializedProto;

void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto);

}  // extern "C"

namespace stream_executor {
namespace tpu {

using SerializedProto = TpuSerializedProto;

// Serializes a proto and puts the result in the given SerializedProto*
// argument.
//
// Users should call SerializedProto_Free on `serialized_proto` afterwards.
template <class Proto>
inline void SerializeProto(const Proto& proto,
                           SerializedProto* serialized_proto) {
  auto size = proto.ByteSizeLong();
  auto bytes = new char[size];
  CHECK(proto.SerializeToArray(bytes, size));
  serialized_proto->size = size;
  serialized_proto->bytes = bytes;
}

// Serializes a proto and returns the result as a SerializedProto value.
//
// Users should call SerializedProto_Free on the return value afterwards.
template <class Proto>
inline SerializedProto SerializeProto(const Proto& proto) {
  SerializedProto serialized_proto;
  SerializeProto(proto, &serialized_proto);
  return serialized_proto;
}

// Deserializes a buffer and returns the corresponding proto. If the buffer is
// empty, returns an empty proto.
template <class Proto>
inline Proto DeserializeProto(const SerializedProto& serialized_proto) {
  Proto proto;
  if (serialized_proto.bytes != nullptr) {
    CHECK_GT(serialized_proto.size, 0);
    CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size))
        << "Invalid buffer, failed to deserialize buffer.";
  }
  return proto;
}

// Releases the memory allocated for serialized protos.
inline void SerializedProto_Free(const SerializedProto& serialized_proto) {
  CHECK_NE(serialized_proto.bytes, nullptr);
  CHECK_GT(serialized_proto.size, 0);
  delete[] serialized_proto.bytes;
}

}  // namespace tpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_
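A short usage sketch for the helpers above (illustrative only; xla::ShapeProto stands in for any proto message, and the function name is an assumption): serialize, hand the bytes across the C boundary, deserialize, then free.

// Illustrative only: proto round trip with the helpers in this header.
void ProtoRoundTripSketch() {
  xla::ShapeProto in;
  in.set_element_type(xla::F32);
  stream_executor::tpu::SerializedProto wire =
      stream_executor::tpu::SerializeProto(in);  // heap-allocates wire.bytes
  auto out = stream_executor::tpu::DeserializeProto<xla::ShapeProto>(wire);
  CHECK_EQ(out.element_type(), xla::F32);
  stream_executor::tpu::SerializedProto_Free(wire);  // releases wire.bytes
}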
38
tensorflow/stream_executor/tpu/status_helper.h
Normal file
38
tensorflow/stream_executor/tpu/status_helper.h
Normal file
@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_

#include "tensorflow/core/platform/status.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

struct StatusHelper {
  StatusHelper() : c_status(TpuStatus_New()) {}
  ~StatusHelper() { TpuStatus_Free(c_status); }
  bool ok() { return TpuStatus_Code(c_status) == 0; }
  tensorflow::Status status() {
    if (!ok()) {
      return tensorflow::Status(
          tensorflow::error::Code(TpuStatus_Code(c_status)),
          TpuStatus_Message(c_status));
    } else {
      return tensorflow::Status::OK();
    }
  }
  SE_Status* c_status;
};

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_
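StatusHelper is the RAII out-parameter used by every wrapper in this commit: construct it, pass c_status to a C call, convert on return. A hedged sketch (TpuNodeContext_StopChipHeartbeats is one such C call, declared later in this commit; the wrapper function name is illustrative):

// Illustrative only: the StatusHelper call pattern used throughout.
tensorflow::Status StopChipHeartbeatsSketch() {
  StatusHelper status;                                 // owns a fresh SE_Status
  TpuNodeContext_StopChipHeartbeats(status.c_status);  // C side fills it in
  return status.status();  // converts to tensorflow::Status; RAII frees it
}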
51
tensorflow/stream_executor/tpu/tpu_computation_placer.cc
Normal file
51
tensorflow/stream_executor/tpu/tpu_computation_placer.cc
Normal file
@ -0,0 +1,51 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h"

#include "tensorflow/stream_executor/tpu/tpu_platform.h"

template <typename T>
using StatusOr = TpuComputationPlacer::StatusOr<T>;

TpuComputationPlacer::TpuComputationPlacer() {
  placer_ = TpuComputationPlacer_New();
}

TpuComputationPlacer::~TpuComputationPlacer() {
  TpuComputationPlacer_Free(placer_);
}

StatusOr<int> TpuComputationPlacer::DeviceId(int replica, int computation,
                                             int replica_count,
                                             int computation_count) {
  LOG(FATAL) << "Unimplemented.";
}

StatusOr<xla::DeviceAssignment> TpuComputationPlacer::AssignDevices(
    int replica_count, int computation_count) {
  LOG(FATAL) << "Unimplemented.";
}

static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
  return std::make_unique<TpuComputationPlacer>();
}

static bool InitModule() {
  xla::ComputationPlacer::RegisterComputationPlacer(
      tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
  return true;
}
static bool module_initialized = InitModule();
41
tensorflow/stream_executor/tpu/tpu_computation_placer.h
Normal file
41
tensorflow/stream_executor/tpu/tpu_computation_placer.h
Normal file
@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_

#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

class TpuComputationPlacer : public xla::ComputationPlacer {
 public:
  template <typename T>
  using StatusOr = xla::StatusOr<T>;

  TpuComputationPlacer();
  ~TpuComputationPlacer() override;

  StatusOr<int> DeviceId(int replica, int computation, int replica_count,
                         int computation_count) override;

  StatusOr<xla::DeviceAssignment> AssignDevices(int replica_count,
                                                int computation_count) override;

 private:
  XLA_ComputationPlacer* placer_;
};

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_
355
tensorflow/stream_executor/tpu/tpu_executor.cc
Normal file
355
tensorflow/stream_executor/tpu/tpu_executor.cc
Normal file
@ -0,0 +1,355 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_executor.h"

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/stream_executor/tpu/tpu_timer.h"

using stream_executor::DeviceMemoryBase;

namespace tensorflow {

namespace {
using ::stream_executor::port::Status;
}  // namespace

TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); }

Status TpuExecutor::Init(int device_ordinal,
                         ::stream_executor::DeviceOptions device_options) {
  StatusHelper status;
  SE_DeviceOptions* options =
      TpuExecutor_NewDeviceOptions(device_options.flags());
  TpuExecutor_Init(executor_, device_ordinal, options, status.c_status);
  TpuExecutor_FreeDeviceOptions(options);
  return status.status();
}

int TpuExecutor::PlatformDeviceCount() {
  return TpuExecutor_PlatformDeviceCount(executor_);
}

void TpuExecutor::SyncAndForgetFailedStreams() {
  TpuExecutor_SyncAndForgetFailedStreams(executor_);
}

bool TpuExecutor::SynchronizeAllActivity() {
  return TpuExecutor_SynchronizeAllActivity(executor_);
}

Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
  StatusHelper status;
  TpuExecutor_BlockHostUntilDone(
      executor_, stream_map().at(stream->implementation()), status.c_status);
  return status.status();
}

Status TpuExecutor::BlockUntilDoneOrFailed() {
  StatusHelper status;
  TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status);
  return status.status();
}

Status TpuExecutor::GetStatus(Stream* stream) {
  StatusHelper status;
  TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()),
                        status.c_status);
  return status.status();
}

bool TpuExecutor::AllocateStream(Stream* stream) {
  return TpuExecutor_AllocateStream(executor_,
                                    stream_map().at(stream->implementation()));
}

void TpuExecutor::DeallocateStream(Stream* stream) {
  TpuExecutor_DeallocateStream(executor_,
                               stream_map().at(stream->implementation()));
  stream_map().erase(stream->implementation());
}

bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
  return TpuExecutor_CreateStreamDependency(
      executor_, stream_map().at(dependent->implementation()),
      stream_map().at(other->implementation()));
}

Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }

Status TpuExecutor::DeallocateEvent(Event* event) { return Status::OK(); }

// AllocateTimer/DeallocateTimer have no specialization.
bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }

void TpuExecutor::DeallocateTimer(Timer* timer) {}

bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
  return TpuExecutor_StartTimer(executor_,
                                stream_map().at(stream->implementation()),
                                timer_map_.at(timer->implementation()));
}

bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
  return TpuExecutor_StopTimer(executor_,
                               stream_map().at(stream->implementation()),
                               timer_map_.at(timer->implementation()));
}

stream_executor::Event::Status TpuExecutor::PollForEventStatus(
    stream_executor::Event* event) {
  return stream_executor::Event::Status(TpuExecutor_PollForEventStatus(
      executor_, event_map_.at(event->implementation())));
}

Status TpuExecutor::RecordEvent(Stream* stream,
                                ::stream_executor::Event* event) {
  StatusHelper status;
  TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()),
                          event_map_.at(event->implementation()),
                          status.c_status);
  return status.status();
}

Status TpuExecutor::WaitForEvent(Stream* stream,
                                 ::stream_executor::Event* event) {
  StatusHelper status;
  TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()),
                           event_map_.at(event->implementation()),
                           status.c_status);
  return status.status();
}

// Implementations for Timer, Stream, Event
// We need to map these implementations to internal equivalents -- thus we
// allocate the internal Timer, Stream and Event operations here, and map
// the implementations to the internal values. The "wrapper" interfaces are
// responsible for deallocating the internal value when they are destroyed.

// Called by Timer::Timer
std::unique_ptr<::stream_executor::internal::TimerInterface>
TpuExecutor::GetTimerImplementation() {
  SE_Timer* tpu_timer = TpuTimer_New(executor_);
  auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
  timer_map_[ptr.get()] = tpu_timer;
  return ptr;
}

// Called by Stream::Stream
std::unique_ptr<::stream_executor::internal::StreamInterface>
TpuExecutor::GetStreamImplementation() {
  SE_Stream* tpu_stream = TpuStream_New(executor_);
  auto ptr = absl::make_unique<TpuStream>(tpu_stream);
  stream_map()[ptr.get()] = tpu_stream;
  return ptr;
}

// Called by Event::Event
std::unique_ptr<::stream_executor::internal::EventInterface>
TpuExecutor::CreateEventImplementation() {
  SE_Event* tpu_event = TpuEvent_New(executor_);
  auto ptr = absl::make_unique<TpuEvent>(tpu_event);
  event_map_[ptr.get()] = tpu_event;
  return ptr;
}

DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
  SE_DeviceMemoryBase se_base =
      TpuExecutor_Allocate(executor_, size, memory_space);
  return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base);
}

void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory);
  TpuExecutor_Deallocate(executor_, &se_base);
}

void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory);
  TpuExecutor_Deallocate(executor_, &se_base);
}

bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
  int64_t _free;
  int64_t _total;
  if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) {
    *free = _free;
    *total = _total;
    return true;
  }
  return false;
}

absl::optional<stream_executor::AllocatorStats>
TpuExecutor::GetAllocatorStats() {
  SE_AllocatorStats c_stats;
  if (TpuExecutor_GetAllocatorStats(executor_, &c_stats)) {
    ::stream_executor::AllocatorStats stats;
    stats.num_allocs = c_stats.num_allocs;
    stats.bytes_in_use = c_stats.bytes_in_use;
    stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
    stats.largest_alloc_size = c_stats.largest_alloc_size;
    if (c_stats.has_bytes_limit) {
      stats.bytes_limit = c_stats.bytes_limit;
    }
    stats.bytes_reserved = c_stats.bytes_reserved;
    stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
    if (c_stats.has_bytes_reservable_limit) {
      stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
    }
    stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
    return stats;
  }
  return {};
}

Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
  StatusHelper status;
  TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index,
                                 status.c_status);
  return status.status();
}

Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
  StatusHelper status;
  TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index,
                                  status.c_status);
  return status.status();
}

void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
                                 absl::Span<uint8> bytes, StatusCallback done) {
  StatusHelper status;
  TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(),
                             bytes.size(), status.c_status);
  done(status.status());
}

Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
                                  absl::Span<const uint8> bytes) {
  StatusHelper status;
  TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(),
                            bytes.size(), status.c_status);
  return status.status();
}

bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
                         const ::stream_executor::DeviceMemoryBase& device_src,
                         uint64 size) {
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
  return TpuExecutor_MemcpyToHost(executor_,
                                  stream_map().at(stream->implementation()),
                                  host_dst, &se_base, size);
}

bool TpuExecutor::Memcpy(Stream* stream,
                         ::stream_executor::DeviceMemoryBase* device_dst,
                         const void* host_src, uint64 size) {
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
  return TpuExecutor_MemcpyFromHost(executor_,
                                    stream_map().at(stream->implementation()),
                                    &se_base, host_src, size);
}

Status TpuExecutor::SynchronousMemcpy(
    ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
    uint64 size) {
  StatusHelper status;
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst);
  TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size,
                                        status.c_status);
  return status.status();
}

Status TpuExecutor::SynchronousMemcpy(
    void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
    uint64 size) {
  StatusHelper status;
  SE_DeviceMemoryBase se_base =
      TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src);
  TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size,
                                      status.c_status);
  return status.status();
}

Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
    ::stream_executor::DeviceMemoryBase* device_dst,
    const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
  return ::stream_executor::port::UnimplementedError(
      "This operation not supported on TPU");
}

bool TpuExecutor::MemcpyDeviceToDevice(
    Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
    const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
  LOG(FATAL) << __func__ << " not supported on TpuExecutor";
}

struct HostCallbackContext {
  std::function<Status()> callback;
};

SE_Status* HostCallbackTrampoline(void* ctx) {
  HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
  Status status = host_ctx->callback();
  SE_Status* c_status =
      TpuStatus_Create(status.code(), status.error_message().c_str());
  delete host_ctx;
  return c_status;
}

bool TpuExecutor::HostCallback(Stream* stream,
                               std::function<Status()> callback) {
  HostCallbackContext* ctx = new HostCallbackContext{callback};
  return TpuExecutor_HostCallback(executor_,
                                  stream_map().at(stream->implementation()),
                                  &HostCallbackTrampoline, ctx);
}

TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
TpuExecutor::CreateDeviceDescription() const {
  StatusHelper status;
  SE_DeviceDescription* description = TpuDeviceDescription_New();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [description]() { TpuDeviceDescription_Free(description); });
  TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status);
  if (status.status().ok()) {
    stream_executor::internal::DeviceDescriptionBuilder builder;
    CHECK_NE(description->device_vendor, nullptr);
    builder.set_device_vendor(description->device_vendor);
    builder.set_name(description->name);
    builder.set_clock_rate_ghz(description->clock_rate_ghz);
    builder.set_core_count(description->core_count);
    builder.set_ecc_enabled(description->ecc_enabled);
    builder.set_device_memory_size(description->device_memory_size);
    builder.set_platform_version(description->platform_version);
    return builder.Build();
  }
  return status.status();
}

}  // namespace tensorflow
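The HostCallbackContext/HostCallbackTrampoline pair above is how a C++ std::function crosses the C ABI: the void* context carries the heap-allocated closure, and the trampoline converts its Status into an SE_Status before deleting the context. A hedged sketch of a caller (the function name is illustrative):

// Illustrative only: enqueue a host callback through the wrapper above.
bool HostCallbackSketch(tensorflow::TpuExecutor* executor,
                        stream_executor::Stream* stream) {
  return executor->HostCallback(stream, []() {
    LOG(INFO) << "Runs on the host once prior stream work completes.";
    return stream_executor::port::Status::OK();
  });
}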
241
tensorflow/stream_executor/tpu/tpu_executor.h
Normal file
241
tensorflow/stream_executor/tpu/tpu_executor.h
Normal file
@ -0,0 +1,241 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"
#include "tensorflow/stream_executor/timer.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
 public:
  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;
  using StatusCallback = std::function<void(const Status&)>;
  using Stream = ::stream_executor::Stream;
  using Event = ::stream_executor::Event;
  using Timer = ::stream_executor::Timer;
  using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase;
  using StreamInterface = ::stream_executor::internal::StreamInterface;
  using StreamExecutorInterface =
      ::stream_executor::internal::StreamExecutorInterface;

  using EventMap =
      absl::flat_hash_map<stream_executor::internal::EventInterface*,
                          SE_Event*>;
  using TimerMap =
      absl::flat_hash_map<stream_executor::internal::TimerInterface*,
                          SE_Timer*>;

  explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform,
                       SE_StreamExecutor* executor)
      : platform_(platform), executor_(executor) {}

  ~TpuExecutor() override;

  Status Init(int device_ordinal,
              ::stream_executor::DeviceOptions device_options) override;

  DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override;

  StatusOr<DeviceMemoryBase> AllocateDeviceMemoryBase(uint64 size,
                                                      int64 memory_space);

  Status AllocateEvent(Event* event) override;

  bool AllocateStream(Stream* stream) override;

  bool AllocateTimer(Timer* timer) override;

  Status BlockHostUntilDone(::stream_executor::Stream* stream) override;

  Status BlockUntilDoneOrFailed();

  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
  CreateDeviceDescription() const override;

  bool CreateStreamDependency(Stream* dependent, Stream* other) override;

  void DeallocateStream(Stream* stream) override;

  void Deallocate(const DeviceMemoryBase& memory);

  void Deallocate(DeviceMemoryBase* memory) override;

  Status DeallocateEvent(Event* event) override;

  void DeallocateTimer(Timer* timer) override;

  bool DeviceMemoryUsage(int64* free, int64* total) const override;

  void DequeueOutfeed(int32 outfeed_queue_index, absl::Span<uint8> bytes,
                      StatusCallback done);

  Status EnqueueInfeed(int32 infeed_queue_index,
                       absl::Span<const uint8> bytes);

  absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;

  Status GetStatus(Stream* stream) override;

  std::unique_ptr<::stream_executor::internal::StreamInterface>
  GetStreamImplementation() override;

  std::unique_ptr<::stream_executor::internal::TimerInterface>
  GetTimerImplementation() override;

  std::unique_ptr<::stream_executor::internal::EventInterface>
  CreateEventImplementation() override;

  bool HostCallback(Stream* stream, std::function<Status()> callback) override;

  bool Memcpy(Stream* stream, void* host_dst,
              const ::stream_executor::DeviceMemoryBase& device_src,
              uint64 size) override;

  bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst,
              const void* host_src, uint64 size) override;

  bool MemcpyDeviceToDevice(Stream* stream,
                            ::stream_executor::DeviceMemoryBase* gpu_dst,
                            const ::stream_executor::DeviceMemoryBase& host_src,
                            uint64 size) override;

  void SyncAndForgetFailedStreams();
  bool SynchronizeAllActivity() override;

  Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst,
                           const void* host_src, uint64 size) override;
  Status SynchronousMemcpy(
      void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
      uint64 size) override;
  Status SynchronousMemcpyDeviceToDevice(
      ::stream_executor::DeviceMemoryBase* device_dst,
      const ::stream_executor::DeviceMemoryBase& device_src,
      uint64 size) override;

  int PlatformDeviceCount() override;

  Event::Status PollForEventStatus(Event* event) override;
  Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override;
  Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override;

  bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override;
  bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override;

  Status WaitForInfeedReady(int32 infeed_queue_index);

  Status WaitForOutfeedReady(int32 outfeed_queue_index);

  const ::tensorflow::tpu::TpuPlatformInterface& platform() const override {
    return *platform_;
  }

  ::tensorflow::tpu::TpuPlatformInterface& platform() override {
    return *platform_;
  }

  // TODO(henrytan): convert this to override once the base interface is
  // changed to TpuExecutorInterface.
  StatusOr<std::unique_ptr<
      tensorflow::tpu::TpuExecutorInterface::TemporaryDeviceMemory>>
  CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
                              int64 size) override {
    LOG(FATAL) << "Unimplemented.";
  }

  // -- Unimplemented (stubbed out) methods.
  std::unique_ptr<stream_executor::internal::KernelInterface>
  CreateKernelImplementation() override {
    LOG(FATAL) << "Not yet implemented";
  }

  stream_executor::SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
    LOG(FATAL) << "not yet implemented";
  }

  void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
                     uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status MemZero(Stream* stream, DeviceMemoryBase* location,
                 uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern,
                  uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status SetDeviceSharedMemoryConfig(
      stream_executor::SharedMemoryConfig config) override {
    LOG(FATAL) << "not yet implemented";
  }
  void* HostMemoryAllocate(uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  void HostMemoryDeallocate(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool HostMemoryRegister(void* mem, uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  bool HostMemoryUnregister(void* mem) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }
  Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                           uint64 size) override {
    LOG(FATAL) << "not yet implemented";
  }

 private:
  EventMap event_map_;
  TimerMap timer_map_;

  TpuPlatform::StreamMap& stream_map() {
    return *(static_cast<TpuPlatform*>(platform_)->stream_map());
  }

  ::tensorflow::tpu::TpuPlatformInterface* platform_;
  SE_StreamExecutor* executor_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
293
tensorflow/stream_executor/tpu/tpu_executor_c_api.h
Normal file
293
tensorflow/stream_executor/tpu/tpu_executor_c_api.h
Normal file
@ -0,0 +1,293 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_

#include <stddef.h>
#include <stdint.h>

#include "tensorflow/c/tf_attrtype.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

typedef struct SE_Platform SE_Platform;
typedef struct SE_StreamExecutor SE_StreamExecutor;
typedef struct SE_Stream SE_Stream;
typedef struct SE_Event SE_Event;
typedef struct SE_Timer SE_Timer;
typedef struct SE_Status SE_Status;

typedef struct SE_PlatformId {
  void* id;  // aka stream_executor::Platform::Id
} SE_PlatformId;
typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig;
typedef struct SE_DeviceOptions SE_DeviceOptions;
typedef SE_Status* (*SE_StatusCallbackFn)(void*);

typedef struct SE_DeviceMemoryBase {
  void* opaque;
  uint64_t size;
  uint64_t payload;
} SE_DeviceMemoryBase;

typedef struct SE_AllocatorStats {
  int64_t num_allocs;
  int64_t bytes_in_use;
  int64_t peak_bytes_in_use;
  int64_t largest_alloc_size;

  bool has_bytes_limit;
  int64_t bytes_limit;

  int64_t bytes_reserved;
  int64_t peak_bytes_reserved;

  bool has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;
} SE_AllocatorStats;

typedef struct SE_DeviceDescription {
  char* device_vendor;
  char* platform_version;
  char* driver_version;
  char* runtime_version;
  char* pci_bus_id;
  char* name;

  int64_t thread_dim_limit_x;
  int64_t thread_dim_limit_y;
  int64_t thread_dim_limit_z;
  int64_t block_dim_limit_x;
  int64_t block_dim_limit_y;
  int64_t block_dim_limit_z;

  int64_t threads_per_core_limit;
  int64_t threads_per_block_limit;
  int64_t threads_per_warp;

  int64_t registers_per_core_limit;
  int64_t registers_per_block_limit;

  int64_t device_address_bits;
  int64_t device_memory_size;
  int64_t memory_bandwidth;

  int64_t shared_memory_per_core;
  int64_t shared_memory_per_block;

  float clock_rate_ghz;

  int cuda_compute_capability_major;
  int cuda_compute_capability_minor;

  int rocm_amdgpu_isa_version;

  int numa_node;
  int core_count;
  bool ecc_enabled;
} SE_DeviceDescription;

typedef struct XLA_TransferManager XLA_TransferManager;

typedef struct XLA_ComputationPlacer XLA_ComputationPlacer;

// Represents an XLA shape tree.
// Shapes are flattened in default traversal order.
typedef struct XLA_Shape {
  char* bytes;
  size_t size;
} XLA_Shape;

// Represents a leaf node for an XLA shaped buffer.
typedef struct XLA_ShapedBuffer {
  XLA_Shape on_host_shape;
  XLA_Shape on_device_shape;
  int device_ordinal;

  SE_DeviceMemoryBase* bases;
  size_t count;
} XLA_ShapedBuffer;

// Represents a leaf XLA literal.
typedef struct XLA_Literal {
  char** buffers;
  size_t* sizes;
  size_t count;
  XLA_Shape shape;
} XLA_Literal;

typedef void (*XLA_CallbackFn)(void*);
typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*);

extern "C" {

SE_Platform* TpuPlatform_New();
void TpuPlatform_Free(SE_Platform* platform);
void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size,
                            const char** options_key,
                            const char** options_value, SE_Status* status);
bool TpuPlatform_Initialized(SE_Platform* platform);
SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform,
                                           SE_StreamExecutorConfig* config,
                                           SE_Status* status);
SE_PlatformId TpuPlatform_Id(SE_Platform* platform);
int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform);
int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform);

void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal,
                      SE_DeviceOptions* device_options, SE_Status* status);
void TpuExecutor_Free(SE_StreamExecutor* executor);

int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor);

SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor,
                                         uint64_t size, int64_t memory_space);
void TpuExecutor_Deallocate(SE_StreamExecutor* executor,
                            SE_DeviceMemoryBase* memory);
bool TpuExecutor_GetAllocatorStats(SE_StreamExecutor* executor,
                                   SE_AllocatorStats* stats);
bool TpuExecutor_DeviceMemoryUsage(SE_StreamExecutor* executor, int64_t* free,
                                   int64_t* total);

bool TpuExecutor_AllocateStream(SE_StreamExecutor* executor, SE_Stream* stream);
void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor,
                                  SE_Stream* stream);
bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
                                        SE_Stream* dependent, SE_Stream* other);
void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
                           SE_Status* status);

void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
                               SE_Status* status);
void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
                                 SE_Status* status);
int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor,
                                   SE_Event* event);
void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream,
                             SE_Event* event, SE_Status* status);
void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream,
                              SE_Event* event, SE_Status* status);

bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer);
bool TpuExecutor_StartTimer(SE_StreamExecutor* executor, SE_Stream* stream,
                            SE_Timer* timer);
bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream,
                           SE_Timer* timer);

void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor,
                                         void* host_dst,
                                         const SE_DeviceMemoryBase* device_src,
                                         uint64_t size, SE_Status* status);
void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor,
                                           SE_DeviceMemoryBase* device_dst,
                                           const void* host_src, uint64_t size,
                                           SE_Status* status);
bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream,
                              void* host_dst,
                              const SE_DeviceMemoryBase* device_src,
                              uint64_t size);

bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream,
                                SE_DeviceMemoryBase* device_dst,
                                const void* host_src, uint64_t size);

void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor,
                               int32_t infeed_queue_index, const uint8_t* data,
                               int64_t size, SE_Status* status);
void TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor,
                                int32_t outfeed_queue_index, uint8_t* data,
                                int64_t size, SE_Status* status);
void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor,
                                    int32_t infeed_queue_index,
                                    SE_Status* status);
void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor,
                                     int32_t outfeed_queue_index,
                                     SE_Status* status);

void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor,
                                    SE_Stream* stream, SE_Status* status);
void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor,
                                        SE_Status* status);
void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor);
bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor);

SE_Stream* TpuStream_New(SE_StreamExecutor* parent);
void TpuStream_Free(SE_Stream*);
void* TpuStream_Stream(SE_Stream*);
bool TpuStream_Status(SE_Stream*);

SE_Event* TpuEvent_New(SE_StreamExecutor* parent);
void TpuEvent_Free(SE_Event*);

SE_Timer* TpuTimer_New(SE_StreamExecutor* parent);
void TpuTimer_Free(SE_Timer*);
int64_t TpuTimer_Nanoseconds(SE_Timer*);
int64_t TpuTimer_Microseconds(SE_Timer*);

SE_Status* TpuStatus_New();
SE_Status* TpuStatus_Create(int32_t code, const char* msg);
void TpuStatus_Free(SE_Status* status);
const char* TpuStatus_Message(SE_Status* status);
int TpuStatus_Code(SE_Status* status);
bool TpuStatus_Ok(SE_Status* status);

SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default();
void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal);
void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*);

SE_DeviceDescription* TpuDeviceDescription_New();
void TpuDeviceDescription_Free(SE_DeviceDescription* description);
void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor,
                                         SE_DeviceDescription* description,
                                         SE_Status* status);

SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags);
void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options);

bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream,
                              SE_StatusCallbackFn callback_fn, void* ctx);

XLA_TransferManager* TpuTransferManager_New();
void TpuTransferManager_Free(XLA_TransferManager* manager);
SE_PlatformId TpuTransferManager_PlatformId(XLA_TransferManager* manager);
void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager,
                                               XLA_Shape* host_shape,
                                               XLA_Shape* device_shape);
void TpuTransferManager_TransferLiteralToDeviceAsync(
    XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal,
    XLA_ShapedBuffer* device_buffer, SE_Status* status);
void TpuTransferManager_TransferLiteralFromDevice(
    XLA_TransferManager* manager, SE_Stream* stream,
    XLA_ShapedBuffer* device_buffer, XLA_Literal* literal,
    XLA_StatusCallbackFn callback, void* ctx);

int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager,
                                                  XLA_Shape* shape);
void TpuTransferManager_WriteSingleTupleIndexTable(
    XLA_TransferManager* manager, SE_Stream* stream,
    SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape,
    SE_DeviceMemoryBase* region, SE_Status* status);

XLA_ComputationPlacer* TpuComputationPlacer_New();
void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer);
}  // extern "C"

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_
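A hedged sketch of driving this C API directly (illustrative only; production code goes through the TpuExecutor wrapper above, and a logging header is assumed): fallible calls take an SE_Status out-parameter, while simple queries return values directly.

// Illustrative only: raw C API usage following the out-parameter convention.
void RawCApiSketch(SE_StreamExecutor* executor) {
  SE_DeviceMemoryBase mem =
      TpuExecutor_Allocate(executor, /*size=*/1024, /*memory_space=*/0);
  int64_t free_bytes = 0;
  int64_t total_bytes = 0;
  if (TpuExecutor_DeviceMemoryUsage(executor, &free_bytes, &total_bytes)) {
    LOG(INFO) << "TPU memory: " << free_bytes << " free of " << total_bytes;
  }
  TpuExecutor_Deallocate(executor, &mem);
}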
64
tensorflow/stream_executor/tpu/tpu_executor_interface.h
Normal file
64
tensorflow/stream_executor/tpu/tpu_executor_interface.h
Normal file
@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/timer.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tpu {
class TpuCore;
}  // namespace tpu

namespace tensorflow {
namespace tpu {

class TpuExecutorInterface
    : public ::stream_executor::internal::StreamExecutorInterface {
 public:
  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;

  class TemporaryDeviceMemory {
   public:
    virtual ~TemporaryDeviceMemory() {}
    virtual stream_executor::DeviceMemoryBase AsDeviceMemoryBase() const = 0;
  };

  virtual StatusOr<std::unique_ptr<TemporaryDeviceMemory>>
  CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
                              int64 size) {
    LOG(FATAL) << "Unimplemented.";
  }

  virtual const TpuPlatformInterface& platform() const {
    LOG(FATAL) << "Unimplemented.";
  }

  virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
100
tensorflow/stream_executor/tpu/tpu_node_context.cc
Normal file
100
tensorflow/stream_executor/tpu/tpu_node_context.cc
Normal file
@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"

namespace tensorflow {
namespace tpu {

using stream_executor::port::Status;
using stream_executor::port::StatusOr;

/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
    int device_ordinal) {
  StatusHelper status;
  XLA_TpuNodeContext* node_context =
      TpuNodeContext_Create(device_ordinal, status.c_status);
  if (!status.status().ok()) {
    TpuNodeContext_Free(node_context);
    return status.status();
  }
  return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
}

TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }

/* static */
Status TpuNodeContext::StopChipHeartbeats() {
  StatusHelper status;
  TpuNodeContext_StopChipHeartbeats(status.c_status);
  return status.status();
}

/* static */
Status TpuNodeContext::CloseTpuHost() {
  StatusHelper status;
  TpuNodeContext_CloseTpuHost(status.c_status);
  return status.status();
}

/* static */
tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
  return TpuPlatformInterface::GetRegisteredPlatform();
}

/* static */
stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
  static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
      new stream_executor::StreamExecutorMemoryAllocator(
          platform(),
          xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
  return memory_allocator;
}

/* static */
xla::Backend* TpuNodeContext::backend() {
  static xla::Backend* backend =
      xla::Backend::CreateBackend(
          xla::BackendOptions().set_platform(platform()))
          .ValueOrDie()
          .release();
  return backend;
}

/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
    int device_ordinal) {
  return backend()->BorrowStream(device_ordinal);
}

/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
    stream_executor::StreamExecutor* executor) {
  return backend()->BorrowStream(executor);
}

/* static */
xla::TransferManager* TpuNodeContext::transfer_manager() {
  return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
}

}  // namespace tpu
}  // namespace tensorflow
89
tensorflow/stream_executor/tpu/tpu_node_context.h
Normal file
89
tensorflow/stream_executor/tpu/tpu_node_context.h
Normal file
@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_

#include <string>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {
namespace tpu {

class TpuNodeContext final {
 public:
  using Status = stream_executor::port::Status;
  template <typename T>
  using StatusOr = stream_executor::port::StatusOr<T>;

  static StatusOr<std::unique_ptr<TpuNodeContext>> Initialize(
      int device_ordinal);

  explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context)
      : device_ordinal_(device_ordinal), node_context_(node_context) {
    CHECK_NE(node_context, nullptr);
  }
  ~TpuNodeContext();

  TpuNodeContext(const TpuNodeContext&) = delete;
  TpuNodeContext& operator=(const TpuNodeContext&) = delete;

  static Status StopChipHeartbeats();

  static Status CloseTpuHost();

  static tensorflow::tpu::TpuPlatformInterface* platform();

  static stream_executor::DeviceMemoryAllocator* memory_allocator();

  static xla::TransferManager* transfer_manager();

  static xla::Backend* backend();

  static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);

  static StatusOr<xla::StreamPool::Ptr> BorrowStream(
      stream_executor::StreamExecutor* executor);

  stream_executor::StreamExecutor* stream_executor() {
    LOG(FATAL) << "Not implemented yet.";
  }

  std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }

  int index_on_host() { LOG(FATAL) << "Not implemented yet."; }

  int device_ordinal() const { return device_ordinal_; }

 private:
  const int device_ordinal_;
  XLA_TpuNodeContext* const node_context_;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_

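For orientation, a minimal usage sketch of the new class (not part of this commit; it assumes the TPU platform has already been registered and initialized, and that TF_ASSIGN_OR_RETURN from xla/status_macros.h is available at the call site):

// Hypothetical call site, shown only to illustrate the API surface above.
tensorflow::Status RunOnDeviceZero() {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tensorflow::tpu::TpuNodeContext> node,
                      tensorflow::tpu::TpuNodeContext::Initialize(0));
  TF_ASSIGN_OR_RETURN(xla::StreamPool::Ptr stream,
                      tensorflow::tpu::TpuNodeContext::BorrowStream(
                          node->device_ordinal()));
  // Enqueue work on *stream; it is returned to the pool on destruction.
  return tensorflow::Status::OK();
}
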
29
tensorflow/stream_executor/tpu/tpu_node_context_c_api.h
Normal file
@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_

#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;

XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                          SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);

void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status);

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_

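This header repeats the commit's C-shim pattern: an opaque handle type plus create/free entry points that report failure through an SE_Status out-parameter. A hedged sketch of how the C++ side presumably consumes it, mirroring the StatusHelper usage in tpu_platform.cc below (not a verbatim excerpt from tpu_node_context.cc):

// Sketch of the wrapper side of the shim; assumed, not quoted from the .cc.
StatusHelper status;
XLA_TpuNodeContext* node_context =
    TpuNodeContext_Create(device_ordinal, status.c_status);
if (!status.ok()) {
  TpuNodeContext_Free(node_context);
  return status.status();
}
return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
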
125
tensorflow/stream_executor/tpu/tpu_platform.cc
Normal file
@ -0,0 +1,125 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_platform.h"

#include "tensorflow/c/tf_status.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

namespace tensorflow {

PLATFORM_DEFINE_ID(TpuPlatform::kId);
TpuPlatform* tpu_registered_platform = nullptr;

using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;

TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); }

TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
  return tpu_registered_platform;
}

Status TpuPlatform::Initialize(
    const std::map<std::string, std::string>& platform_options) {
  StatusHelper status;

  size_t options_size = platform_options.size();
  const char** options_key =
      static_cast<const char**>(malloc(sizeof(const char*) * options_size));
  const char** options_value =
      static_cast<const char**>(malloc(sizeof(const char*) * options_size));

  size_t i = 0;
  for (const auto& option : platform_options) {
    options_key[i] = option.first.c_str();
    options_value[i] = option.second.c_str();
    i++;
  }

  TpuPlatform_Initialize(platform_, options_size, options_key, options_value,
                         status.c_status);

  free(options_key);
  free(options_value);

  return status.status();
}

TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); }

int TpuPlatform::VisibleDeviceCount() const {
  return TpuPlatform_VisibleDeviceCount(platform_);
}

StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
    const ::stream_executor::StreamExecutorConfig& config) {
  return executor_cache_.GetOrCreate(
      config, [&]() { return GetUncachedExecutor(config); });
}

StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
TpuPlatform::GetUncachedExecutor(
    const ::stream_executor::StreamExecutorConfig& config) {
  SE_StreamExecutorConfig* c_config = TpuStreamExecutorConfig_Default();

  TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal);

  StatusHelper status;
  SE_StreamExecutor* executor =
      TpuPlatform_GetExecutor(platform_, c_config, status.c_status);
  TpuStreamExecutorConfig_Free(c_config);
  if (!status.ok()) {
    return status.status();
  }
  return std::make_unique<stream_executor::StreamExecutor>(
      this, absl::make_unique<tensorflow::TpuExecutor>(this, executor),
      config.ordinal);
}

::stream_executor::Platform::Id TpuPlatform::id() const {
  return TpuPlatform::kId;
}

const std::string& TpuPlatform::Name() const {
  static std::string* name = new std::string(kName);
  return *name;
}

int64 TpuPlatform::TpuMemoryLimit() {
  return TpuPlatform_TpuMemoryLimit(platform_);
}

}  // namespace tensorflow

void RegisterTpuPlatform() {
  tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
  std::unique_ptr<stream_executor::Platform> platform(
      tensorflow::tpu_registered_platform);
  SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
      std::move(platform)));
}

REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform());

// Note that module initialization sequencing is not supported in the
// open-source project, so this will be a no-op there.
REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager);
REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener,
                                     tpu_platform);

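Once RegisterTpuPlatform() has run, the platform is discoverable by name through MultiPlatformManager. A minimal lookup sketch (assumes the registration initializer above has already executed; this is the same lookup TpuPlatformInterface::GetRegisteredPlatform() tries first):

auto status_or_platform =
    stream_executor::MultiPlatformManager::PlatformWithName("TPU");
if (status_or_platform.ok()) {
  stream_executor::Platform* platform = status_or_platform.ValueOrDie();
  LOG(INFO) << "TPU platform reports " << platform->VisibleDeviceCount()
            << " visible device(s)";
}
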
121
tensorflow/stream_executor/tpu/tpu_platform.h
Normal file
@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 public:
  using StreamMap =
      absl::flat_hash_map<stream_executor::internal::StreamInterface*,
                          SE_Stream*>;

  static const ::stream_executor::Platform::Id kId;
  static constexpr char kName[] = "TPU";

  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;

  TpuPlatform();

  ~TpuPlatform() override;

  static TpuPlatform* GetRegisteredPlatform();

  Id id() const override;

  const std::string& Name() const override;

  int VisibleDeviceCount() const override;

  int64 TpuMemoryLimit() override;

  bool Initialized() const override {
    return TpuPlatform_Initialized(platform_);
  }

  Status Initialize(
      const std::map<std::string, std::string>& platform_options) override;

  Status Reset() override { return Reset(false); }

  Status Reset(bool only_tear_down) override {
    LOG(FATAL) << "Not yet implemented";
  }

  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
  DescriptionForDevice(int ordinal) const override {
    LOG(FATAL) << "Not yet implemented";
  }

  StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice(
      int ordinal) override {
    stream_executor::StreamExecutorConfig config;
    config.ordinal = ordinal;
    return GetExecutor(config);
  }

  StatusOr<::stream_executor::StreamExecutor*>
  ExecutorForDeviceWithPluginConfig(
      int ordinal,
      const ::stream_executor::PluginConfig& plugin_config) override {
    stream_executor::StreamExecutorConfig config;
    config.ordinal = ordinal;
    config.plugin_config = plugin_config;
    return GetExecutor(config);
  }

  StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
      const ::stream_executor::StreamExecutorConfig& config) override;

  StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
  GetUncachedExecutor(
      const ::stream_executor::StreamExecutorConfig& config) override;

  void RegisterTraceListener(
      std::unique_ptr<stream_executor::TraceListener> listener) override {
    LOG(FATAL) << "Not yet implemented";
  }

  void UnregisterTraceListener(
      stream_executor::TraceListener* listener) override {
    LOG(FATAL) << "Not yet implemented";
  }

  StreamMap* stream_map() { return &stream_map_; }

 private:
  SE_Platform* platform_;

  stream_executor::ExecutorCache executor_cache_;
  StreamMap stream_map_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_

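The ExecutorForDevice overrides above all funnel into the cached GetExecutor path. A short sketch of resolving an executor for ordinal 0 (hypothetical call site; the returned executor is owned by the platform's ExecutorCache):

tensorflow::TpuPlatform* platform =
    tensorflow::TpuPlatform::GetRegisteredPlatform();
if (platform != nullptr) {
  auto executor_or = platform->ExecutorForDevice(0);
  if (executor_or.ok()) {
    stream_executor::StreamExecutor* executor = executor_or.ValueOrDie();
    // Owned by executor_cache_; callers must not delete it.
  }
}
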
63
tensorflow/stream_executor/tpu/tpu_platform_interface.cc
Normal file
@ -0,0 +1,63 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

#include "tensorflow/stream_executor/multi_platform_manager.h"

namespace tensorflow {
namespace tpu {

/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
  // Prefer TpuPlatform if it's registered.
  auto status_or_tpu_platform =
      stream_executor::MultiPlatformManager::PlatformWithName("TPU");
  if (status_or_tpu_platform.ok()) {
    return static_cast<TpuPlatformInterface*>(
        status_or_tpu_platform.ValueOrDie());
  }
  if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
    LOG(WARNING) << "Error when getting the TPU platform: "
                 << status_or_tpu_platform.status();
    return nullptr;
  }

  // Use any other registered TPU platform.
  auto status_or_other_tpu_platforms =
      stream_executor::MultiPlatformManager::PlatformsWithFilter(
          [](const stream_executor::Platform* platform) {
            return dynamic_cast<const TpuPlatformInterface*>(platform) !=
                   nullptr;
          });
  if (!status_or_other_tpu_platforms.ok()) {
    LOG(WARNING) << "Error when getting other TPU platforms: "
                 << status_or_other_tpu_platforms.status();
    return nullptr;
  }
  auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
  if (!other_tpu_platforms.empty()) {
    LOG(WARNING) << other_tpu_platforms.size()
                 << " TPU platforms registered, selecting "
                 << other_tpu_platforms[0]->Name();
    return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
  }

  LOG(WARNING) << "No TPU platform registered";
  return nullptr;
}

}  // namespace tpu
}  // namespace tensorflow

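Callers of GetRegisteredPlatform() must tolerate a null result, since every failure path above returns nullptr rather than a status. A hedged call-site sketch (the use of tensorflow::errors::NotFound here is an assumption about the caller, not part of this commit):

tensorflow::tpu::TpuPlatformInterface* platform =
    tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform();
if (platform == nullptr) {
  // Registration may not have happened yet (or failed); surface an error.
  return tensorflow::errors::NotFound("No TPU platform registered");
}
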
44
tensorflow/stream_executor/tpu/tpu_platform_interface.h
Normal file
@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_

#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/platform.h"

namespace tensorflow {
namespace tpu {

class TpuPlatformInterface : public stream_executor::Platform {
 public:
  using Status = stream_executor::port::Status;

  // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms
  // are registered, finds the most suitable one. Returns nullptr if no TPU
  // platform is registered or an error occurred.
  static TpuPlatformInterface* GetRegisteredPlatform();

  virtual Status Reset() { return Reset(false); }

  virtual Status Reset(bool only_tear_down) = 0;

  virtual int64 TpuMemoryLimit() = 0;
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_

40
tensorflow/stream_executor/tpu/tpu_stream.h
Normal file
@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_

#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

class TpuStream : public stream_executor::internal::StreamInterface {
 public:
  explicit TpuStream(SE_Stream* stream) : stream_(stream) {}
  ~TpuStream() override { TpuStream_Free(stream_); }

 private:
  SE_Stream* stream_;
};

class TpuEvent : public ::stream_executor::internal::EventInterface {
 public:
  explicit TpuEvent(SE_Event* event) : event_(event) {}
  ~TpuEvent() override { TpuEvent_Free(event_); }

 private:
  SE_Event* event_;
};

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_

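Both wrappers are thin RAII owners of a C handle: construction adopts the pointer, and the destructor releases it through the matching _Free call. A sketch of a construction site (TpuStream_New is a hypothetical creator name; only the _Free counterparts appear in this header):

// Hypothetical: adopt a freshly created C stream handle. TpuStream_Free
// runs automatically when `stream` goes out of scope.
SE_Stream* c_stream = TpuStream_New(/*...*/);  // hypothetical creator
auto stream = absl::make_unique<TpuStream>(c_stream);
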
30
tensorflow/stream_executor/tpu/tpu_stream_interface.h
Normal file
@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_

#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace tensorflow {
namespace tpu {

class TpuStreamInterface : public ::stream_executor::internal::StreamInterface {
};

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_

38
tensorflow/stream_executor/tpu/tpu_timer.h
Normal file
@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_

#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

namespace tensorflow {

class TpuTimer : public ::stream_executor::internal::TimerInterface {
 public:
  explicit TpuTimer(SE_Timer* timer) : timer_(timer) {}
  ~TpuTimer() override { TpuTimer_Free(timer_); }
  uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); }
  uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); }

 private:
  SE_Timer* timer_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_

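TpuTimer follows the same adopt-and-forward shape, and reading it back is a pair of C calls. A trivial usage sketch (c_timer is an assumed, already-created SE_Timer handle):

tensorflow::TpuTimer timer(c_timer);  // adopts c_timer; freed in ~TpuTimer
uint64 us = timer.Microseconds();     // forwards to TpuTimer_Microseconds
uint64 ns = timer.Nanoseconds();      // forwards to TpuTimer_Nanoseconds
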
167
tensorflow/stream_executor/tpu/tpu_transfer_manager.cc
Normal file
@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"

namespace tensorflow {

using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;

TpuTransferManager::TpuTransferManager() {
  manager_ = TpuTransferManager_New();
}

TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }

stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return TpuPlatform::kId;
}

xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;

  TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);

  TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
                                            &c_device_shape);
  xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
  TpuConversions::CShapeCleanup(&c_host_shape);
  TpuConversions::CShapeCleanup(&c_device_shape);
  return device_shape;
}

Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
                                                 &c_device_buffer);

  TpuTransferManager_TransferLiteralToDeviceAsync(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  TpuConversions::CShapedBufferCleanup(&c_device_buffer);
  TpuConversions::CLiteralCleanup(&c_literal);
  return status.status();
}

struct TransferFromDeviceState {
  std::atomic<int64_t> remaining_transfers;
  StatusHelper status_helper;
  std::function<void(Status)> done;

  void TransferFinished(SE_Status* status) {
    if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
      status_helper.c_status = status;
    } else {
      TpuStatus_Free(status);
    }

    if (--remaining_transfers == 0) {
      done(status_helper.status());
      delete this;
    }
  }
};

void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}

void TpuTransferManager::TransferLiteralFromDevice(
    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  TransferFromDeviceState* state = new TransferFromDeviceState;
  state->remaining_transfers = 1;
  state->done = done;
  XLA_ShapedBuffer c_device_buffer;
  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
                                                 &c_device_buffer);
  XLA_Literal c_literal;
  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);

  TpuTransferManager_TransferLiteralFromDevice(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
  TpuConversions::CShapedBufferCleanup(&c_device_buffer);
  TpuConversions::CLiteralCleanup(&c_literal);
}

int64 TpuTransferManager::GetByteSizeRequirement(
    const xla::Shape& shape) const {
  XLA_Shape c_shape;
  TpuConversions::XlaShapeToCShape(shape, &c_shape);

  int64 size_in_bytes =
      TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape);

  TpuConversions::CShapeCleanup(&c_shape);
  return size_in_bytes;
}

Status TpuTransferManager::WriteSingleTupleIndexTable(
    stream_executor::Stream* stream,
    absl::Span<const stream_executor::DeviceMemoryBase> elements,
    const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
  CHECK_GT(elements.size(), 0);
  SE_DeviceMemoryBase* elements_bases =
      new SE_DeviceMemoryBase[elements.size()];
  for (int i = 0; i < elements.size(); i++) {
    elements_bases[i] =
        SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
                            elements[i].size(), elements[i].payload()};
  }
  XLA_Shape c_shape;
  TpuConversions::XlaShapeToCShape(shape, &c_shape);
  SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
                                  region->payload()};
  StatusHelper status;

  TpuTransferManager_WriteSingleTupleIndexTable(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      elements_bases, elements.size(), &c_shape, &region_base, status.c_status);

  delete[] elements_bases;
  TpuConversions::CShapeCleanup(&c_shape);
  return status.status();
}

}  // namespace tensorflow

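TransferLiteralFromDevice is asynchronous: TransferFromDeviceState reference-counts outstanding transfers, and the trampoline invokes `done` once the count reaches zero. A sketch of a blocking caller bridging the callback with absl::Notification (the Notification usage and surrounding variables are assumptions, not part of this commit):

absl::Notification transfer_done;
tensorflow::Status transfer_status;
transfer_manager->TransferLiteralFromDevice(
    stream, device_buffer, literal,
    [&](tensorflow::Status s) {
      transfer_status = s;
      transfer_done.Notify();
    },
    /*transfer_metadata=*/nullptr);
transfer_done.WaitForNotification();  // blocks until the trampoline fires
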
83
tensorflow/stream_executor/tpu/tpu_transfer_manager.h
Normal file
@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"

namespace tensorflow {

class TpuTransferManager : public xla::TransferManager {
 public:
  TpuTransferManager();
  ~TpuTransferManager() override;

  using Status = stream_executor::port::Status;
  template <typename T>
  using StatusOr = stream_executor::port::StatusOr<T>;

  stream_executor::Platform::Id PlatformId() const override;

  xla::Shape HostShapeToDeviceShape(
      const xla::Shape& host_shape) const override;

  Status TransferLiteralToDeviceAsync(
      stream_executor::Stream* stream, const xla::LiteralSlice& literal,
      const xla::ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata) override;

  void TransferLiteralFromDevice(
      stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
      xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata) override;

  Status TransferLiteralToInfeed(stream_executor::StreamExecutor* executor,
                                 const xla::LiteralSlice& literal) override {
    LOG(FATAL) << "Not yet implemented";
  }

  Status TransferLiteralFromOutfeed(
      stream_executor::StreamExecutor* executor,
      const xla::Shape& literal_shape,
      xla::MutableBorrowingLiteral literal) override {
    LOG(FATAL) << "Not yet implemented";
  }

  Status ResetDevices(
      absl::Span<stream_executor::StreamExecutor* const> executor) override {
    LOG(FATAL) << "Not yet implemented";
  }

  int64 GetByteSizeRequirement(const xla::Shape& shape) const override;

  Status WriteSingleTupleIndexTable(
      stream_executor::Stream* stream,
      absl::Span<const stream_executor::DeviceMemoryBase> elements,
      const xla::Shape& shape,
      stream_executor::DeviceMemoryBase* region) override;

 private:
  XLA_TransferManager* manager_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_

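Of the overridden methods, only the transfer and sizing paths are wired through to the C API in this commit; infeed, outfeed, and device reset still LOG(FATAL). A sizing sketch using the implemented override (the shape choice and the transfer_manager pointer are illustrative assumptions):

// How many bytes the device layout needs for a 128x256 f32 array.
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 256});
tensorflow::int64 bytes = transfer_manager->GetByteSizeRequirement(shape);
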
@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"

namespace tensorflow {

static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
  return std::make_unique<TpuTransferManager>();
}

static bool InitModule() {
  xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
                                                CreateTpuTransferManager);
  return true;
}
static bool module_initialized = InitModule();

}  // namespace tensorflow
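
The module-level bool forces InitModule() to run at load time, so the factory is in XLA's registry before any TPU op looks it up. After this translation unit is linked in, the manager resolves by platform id, which is exactly the lookup TpuNodeContext::transfer_manager() performs:

// Sketch: resolving the registered transfer manager by platform.
xla::TransferManager* tm =
    xla::TransferManager::GetForPlatform(
        tensorflow::TpuPlatform::GetRegisteredPlatform())
        .ValueOrDie();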