diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD new file mode 100644 index 00000000000..ef9e4a0a41e --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -0,0 +1,55 @@ +# Contains graph rewrites for TPU runtimes and optimizations. + +package( + default_visibility = [ + "//tensorflow/core/tpu:__subpackages__", + "//tensorflow/stream_executor/tpu:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "distributed_tpu_configuration_rewrite_registration", + srcs = ["distributed_tpu_configuration_rewrite_registration.cc"], + deps = [ + ":distributed_tpu_configuration_rewrite_pass", + "//tensorflow/core:core_cpu", + ], + alwayslink = 1, +) + +cc_library( + name = "distributed_tpu_configuration_rewrite_pass", + srcs = [ + "distributed_tpu_configuration_rewrite_pass.cc", + ], + hdrs = [ + "distributed_tpu_configuration_rewrite_pass.h", + ], + deps = [ + ":distributed_tpu_rewrite_helpers", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "//tensorflow/core/tpu:tpu_init_mode", + "//tensorflow/core/tpu/kernels:tpu_compile_op_options", + ], +) + +cc_library( + name = "distributed_tpu_rewrite_helpers", + srcs = ["distributed_tpu_rewrite_helpers.cc"], + hdrs = ["distributed_tpu_rewrite_helpers.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tpu:tpu_defs", + ], +) diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc new file mode 100644 index 00000000000..3b1e9d79705 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.cc @@ -0,0 +1,402 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Configuration for distributed TPU jobs + +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h" + +#include + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h" +#include "tensorflow/core/tpu/tpu_init_mode.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +namespace { + +constexpr char kIdentityOp[] = "Identity"; +constexpr char kConfigureOp[] = "ConfigureDistributedTPU"; +constexpr char kInternalConfigureOp[] = "_ConfigureDistributedTPU"; +constexpr char kWaitOp[] = "_WaitForDistributedTPU"; +constexpr char kHostConfigureOp[] = "_InitializeHostForDistributedTPU"; +constexpr char kGlobalTPUArrayOp[] = "_SetGlobalTPUArray"; +constexpr char kShutdownOp[] = "ShutdownDistributedTPU"; +constexpr char kInternalShutdownOp[] = "_ShutdownDistributedTPU"; +constexpr char kHostDisconnectOp[] = "_DisconnectHostFromDistributedTPUSystem"; +constexpr char kEmbeddingConfigurationAttr[] = "embedding_config"; +constexpr int kDefaultStartupTimeout = 20; + +Status AddConfigurationNode(const string& configuration_device_name, + int number_of_hosts, Graph* graph, + bool enable_whole_mesh_compilations, + Node** configuration_node) { + NodeDef config_def; + config_def.set_name(graph->NewName("configure_distributed_tpu")); + config_def.set_op(kInternalConfigureOp); + config_def.set_device(configuration_device_name); + AddNodeAttr("N", number_of_hosts, &config_def); + AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations, + &config_def); + // TODO(shikharagarwal): Fill with appropriate original node debug info. 
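+  // Graph::AddNode reports any problem through the `status` out-parameter;
+  // the returned Node* is only used below once that status has been checked.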
+ + Status status; + *configuration_node = graph->AddNode(config_def, &status); + if (!status.ok()) { + return status; + } + (*configuration_node)->set_assigned_device_name(configuration_device_name); + return Status::OK(); +} + +Status AddHostConfigNode(const string& host_device_name, + Node* configuration_node, Graph* graph, + bool enable_whole_mesh_compilations, + Node** host_configuration_node) { + NodeDef host_config_def; + host_config_def.set_name(graph->NewName("configure_tpu_host")); + host_config_def.set_op(kHostConfigureOp); + host_config_def.set_device(host_device_name); + AddNodeAttr("enable_whole_mesh_compilations", enable_whole_mesh_compilations, + &host_config_def); + MergeDebugInfo(NodeDebugInfo(configuration_node->def()), &host_config_def); + + Status status; + *host_configuration_node = graph->AddNode(host_config_def, &status); + if (!status.ok()) { + return status; + } + (*host_configuration_node)->set_assigned_device_name(host_device_name); + graph->AddEdge(configuration_node, 0, *host_configuration_node, 0); + return Status::OK(); +} + +Status AddWaitNode(const string& configuration_device_name, + const std::vector& host_configuration_nodes, + Graph* graph, Node** wait_node) { + NodeDef wait_def; + wait_def.set_name(graph->NewName("wait_for_distributed_tpu_system")); + wait_def.set_op(kWaitOp); + wait_def.set_device(configuration_device_name); + AddNodeAttr("N", static_cast(host_configuration_nodes.size()), + &wait_def); + AddNodeAttr("startup_timeout_sec", kDefaultStartupTimeout, &wait_def); + if (!host_configuration_nodes.empty()) { + MergeDebugInfo(NodeDebugInfo(host_configuration_nodes[0]->def()), + &wait_def); + } + + Status status; + *wait_node = graph->AddNode(wait_def, &status); + if (!status.ok()) { + return status; + } + (*wait_node)->set_assigned_device_name(configuration_device_name); + // Get the inputs from the host configuration nodes. + for (int i = 0; i < host_configuration_nodes.size(); ++i) { + graph->AddEdge(host_configuration_nodes[i], 0, *wait_node, i); + } + return Status::OK(); +} + +Status AddGlobalTPUArrayNode(const string& host_device_name, Node* wait_node, + Graph* graph, Node** global_tpu_array_node) { + NodeDef global_tpu_array_def; + global_tpu_array_def.set_name(graph->NewName("set_global_tpu_array")); + global_tpu_array_def.set_op(kGlobalTPUArrayOp); + global_tpu_array_def.set_device(host_device_name); + MergeDebugInfo(NodeDebugInfo(wait_node->def()), &global_tpu_array_def); + + Status status; + *global_tpu_array_node = graph->AddNode(global_tpu_array_def, &status); + if (!status.ok()) { + return status; + } + (*global_tpu_array_node)->set_assigned_device_name(host_device_name); + graph->AddEdge(wait_node, 0, *global_tpu_array_node, 0); + return Status::OK(); +} + +Status AddSynchronizationNode( + const NodeDef& sync_node_def, const string& device_name, + const std::vector& global_array_id_nodes, Node* wait_node, + const std::vector& + output_dependencies, + Graph* graph) { + NodeDef sync_def; + sync_def.set_name(sync_node_def.name()); + sync_def.set_op(kIdentityOp); + sync_def.set_device(device_name); + AddNodeAttr("T", DT_STRING, &sync_def); + MergeDebugInfo(NodeDebugInfo(sync_node_def), &sync_def); + + Status status; + Node* sync_node = graph->AddNode(sync_def, &status); + if (!status.ok()) { + return status; + } + sync_node->set_assigned_device_name(device_name); + // Add control edges from the global array id nodes. 
+ for (auto node : global_array_id_nodes) { + graph->AddControlEdge(node, sync_node); + } + // Forward the data from the wait node. + graph->AddEdge(wait_node, 0, sync_node, 0); + // Replace the output edges. + for (const DistributedTPURewriteHelpers::OutputDependency& dep : + output_dependencies) { + if (dep.dst_input == Graph::kControlSlot) { + graph->AddControlEdge(sync_node, dep.dst); + } else { + graph->AddEdge(sync_node, dep.src_output, dep.dst, dep.dst_input); + } + } + return Status::OK(); +} + + +Status AddShutdownNode( + const NodeDef& shutdown_node_def, const string& shutdown_device_name, + const std::vector& + output_dependencies, + Graph* graph, Node** shutdown_node) { + NodeDef shutdown_def; + shutdown_def.set_name(shutdown_node_def.name()); + shutdown_def.set_op(kInternalShutdownOp); + shutdown_def.set_device(shutdown_device_name); + MergeDebugInfo(NodeDebugInfo(shutdown_node_def), &shutdown_def); + + Status status; + *shutdown_node = graph->AddNode(shutdown_def, &status); + if (!status.ok()) { + return status; + } + (*shutdown_node)->set_assigned_device_name(shutdown_device_name); + // Replace the output control edges. + for (const DistributedTPURewriteHelpers::OutputDependency& dep : + output_dependencies) { + if (dep.dst_input != Graph::kControlSlot) { + return errors::Internal("Shutdown node had non-control edge output"); + } + graph->AddControlEdge(*shutdown_node, dep.dst); + } + return Status::OK(); +} + +Status AddHostDisconnectNode(const string& host_device_name, + const std::vector& input_dependencies, + Node* post_disconnect_node, int output_index, + Graph* graph) { + NodeDef host_disconnect_def; + host_disconnect_def.set_name(graph->NewName("disconnect_tpu_host")); + host_disconnect_def.set_op(kHostDisconnectOp); + host_disconnect_def.set_device(host_device_name); + MergeDebugInfo(NodeDebugInfo(post_disconnect_node->def()), + &host_disconnect_def); + + Status status; + Node* host_disconnect_node = graph->AddNode(host_disconnect_def, &status); + if (!status.ok()) { + return status; + } + host_disconnect_node->set_assigned_device_name(host_device_name); + // Replace the input control edges. + for (Node* src_node : input_dependencies) { + graph->AddControlEdge(src_node, host_disconnect_node); + } + if (output_index == -1) { + graph->AddControlEdge(host_disconnect_node, post_disconnect_node); + } else { + graph->AddEdge(host_disconnect_node, 0, post_disconnect_node, output_index); + } + return Status::OK(); +} + +} // namespace + +Status DistributedTPUConfigurationRewritePass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "DistributedTPUConfigurationRewritePass::Run"; + + Graph* graph = options.graph->get(); + + if (VLOG_IS_ON(1)) { + DumpGraphToFile("distributed_tpu_configuration_before", *graph, + options.flib_def); + } + + // This pass can only run in the session master, which should fill + // in the device_set field to the options. 
+ TF_RET_CHECK(options.device_set != nullptr); + + TF_RETURN_IF_ERROR( + DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType( + kConfigureOp, graph, *options.device_set, + [](const NodeDef& configuration_node_def, + const string& configuration_device_name, + const std::vector& host_devices, + const std::vector& input_dependencies, + const std::vector& + output_dependencies, + Graph* graph) -> Status { + const std::string& embedding_attr_string = GetNodeAttrString( + AttrSlice(configuration_node_def), kEmbeddingConfigurationAttr); + + if (!embedding_attr_string.empty()) { + return errors::InvalidArgument("embedding_config must be empty."); + } + + bool is_global_init = false; + bool enable_whole_mesh_compilations = false; + TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def, + "is_global_init", &is_global_init)); + TryGetNodeAttr(configuration_node_def, + "enable_whole_mesh_compilations", + &enable_whole_mesh_compilations); + TF_RETURN_IF_ERROR(SetTPUInitMode( + is_global_init ? TPUInitMode::kGlobal : TPUInitMode::kRegular)); + + bool compilation_failure_closes_chips; + TF_RETURN_IF_ERROR(GetNodeAttr(configuration_node_def, + "compilation_failure_closes_chips", + &compilation_failure_closes_chips)); + internal::SetTpuCompilationFailureClosesChips( + compilation_failure_closes_chips); + + // Add the global TPU system configuration node. + Node* configuration_node; + TF_RETURN_IF_ERROR(AddConfigurationNode( + configuration_device_name, host_devices.size(), graph, + enable_whole_mesh_compilations, &configuration_node)); + + // Add the host disconnect nodes. + for (int i = 0; i < host_devices.size(); ++i) { + const auto host_device = host_devices[i]; + TF_RETURN_IF_ERROR( + AddHostDisconnectNode(host_device->name(), input_dependencies, + configuration_node, i, graph)); + } + + // Add the host configuration nodes. + std::vector host_configuration_nodes; + for (const auto host_device : host_devices) { + Node* host_configuration_node; + TF_RETURN_IF_ERROR(AddHostConfigNode( + host_device->name(), configuration_node, graph, + enable_whole_mesh_compilations, &host_configuration_node)); + host_configuration_nodes.push_back(host_configuration_node); + } + + // Add the node to wait for the system configuration to + // stabilize. Use the name of the original dummy Op in case it was + // the target of a Session::Run call. + Node* wait_node; + TF_RETURN_IF_ERROR(AddWaitNode(configuration_device_name, + host_configuration_nodes, graph, + &wait_node)); + + // Add the nodes to set the global TPU ids at each host. 
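+                // Each host gets its own _SetGlobalTPUArray node, fed from
+                // output 0 of the wait node, so every worker records the
+                // global TPU topology before the synchronization Identity
+                // node can run.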
+ std::vector global_array_id_nodes; + for (const auto host_device : host_devices) { + Node* global_array_id_node; + TF_RETURN_IF_ERROR(AddGlobalTPUArrayNode(host_device->name(), + wait_node, graph, + &global_array_id_node)); + global_array_id_nodes.push_back(global_array_id_node); + } + + if (host_devices.empty()) { + return errors::InvalidArgument("TPU job contains no CPU devices"); + } + TF_RET_CHECK(!host_devices.empty()); + + TF_RETURN_IF_ERROR(AddSynchronizationNode( + configuration_node_def, host_devices.front()->name(), + global_array_id_nodes, wait_node, output_dependencies, graph)); + + return Status::OK(); + })); + + if (VLOG_IS_ON(1)) { + DumpGraphToFile("distributed_tpu_configuration_after", *graph, + options.flib_def); + } + + VLOG(1) << "DistributedTPUConfigurationRewritePass::Run() finished"; + return Status::OK(); +} + +Status DistributedTPUShutdownRewritePass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "DistributedTPUShutdownRewritePass::Run"; + + Graph* graph = options.graph->get(); + + if (VLOG_IS_ON(1)) { + DumpGraphToFile("distributed_tpu_shutdown_before", *graph, + options.flib_def); + } + + // This pass can only run in the session master, which should fill + // in the device_set field to the options. + TF_RET_CHECK(options.device_set != nullptr); + + TF_RETURN_IF_ERROR( + DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType( + kShutdownOp, graph, *options.device_set, + [](const NodeDef& shutdown_node_def, + const string& shutdown_device_name, + const std::vector& host_devices, + const std::vector& input_dependencies, + const std::vector& + output_dependencies, + Graph* graph) -> Status { + Node* shutdown_node; + TF_RETURN_IF_ERROR( + AddShutdownNode(shutdown_node_def, shutdown_device_name, + output_dependencies, graph, &shutdown_node)); + + // Add the host disconnect nodes. + for (const auto host_device : host_devices) { + TF_RETURN_IF_ERROR( + AddHostDisconnectNode(host_device->name(), input_dependencies, + shutdown_node, -1, graph)); + } + + return Status::OK(); + })); + + if (VLOG_IS_ON(1)) { + DumpGraphToFile("distributed_tpu_shutdown_after", *graph, options.flib_def); + } + + VLOG(1) << "DistributedTPUShutdownRewritePass::Run() finished"; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h new file mode 100644 index 00000000000..191f32f9505 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites ConfigureDistributedTPU Op into a graph that configures each host. 
+// +// See the comment at the top of +// third_party/tensorflow/core/ops/tpu_configuration_ops.cc to see the +// sequence of Ops used to configure a distributed TPU system. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM +// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU +// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU +// device of each host in the same job as the given TPU_SYSTEM device. +class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM +// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and +// _DisconnectHostFromDistributedTPUSystem on the CPU device of each +// host in the same job as the given TPU_SYSTEM device. +class DistributedTPUShutdownRewritePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc new file mode 100644 index 00000000000..db2b3a53f20 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_registration.cc @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h" + +namespace tensorflow { +namespace { + +// This pass removes the TPUEmbeddingConfiguration in ConfigureDistributedTPU. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20, + DistributedTPUConfigurationRewritePass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20, + DistributedTPUShutdownRewritePass); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc new file mode 100644 index 00000000000..965a17481cb --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc @@ -0,0 +1,255 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for TPU rewrite passes. + +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h" + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// LINT.IfChange +Status DistributedTPURewriteHelpers::GetSystemDevice( + const string& system_spec_string, const DeviceSet& device_set, + DeviceNameUtils::ParsedName* system_spec, Device** system_device) { + if (!DeviceNameUtils::ParseFullName(system_spec_string, system_spec)) { + system_spec->Clear(); + } + + // Callers may have relied on an Op only being registered on TPU_SYSTEM + // devices to ensure the Op is placed there. Augment the device spec to make + // the device type explicit. + if (!system_spec->has_type || system_spec->type != DEVICE_TPU_SYSTEM) { + system_spec->type = DEVICE_TPU_SYSTEM; + system_spec->has_type = true; + system_spec->id = 0; + system_spec->has_id = true; + } + + std::vector system_devices; + device_set.FindMatchingDevices(*system_spec, &system_devices); + if (system_devices.empty()) { + if (system_spec_string.empty()) { + return errors::InvalidArgument( + "No TPU_SYSTEM device found. Please ensure that you're connected to " + "a host with a TPU_SYSTEM device."); + } + return errors::InvalidArgument("No matching devices found for '", + system_spec_string, "'"); + } else if (system_devices.size() > 1) { + // Validate that all system devices are part of the same job. + std::unordered_set job_names; + for (auto device : system_devices) { + const auto& parsed_name = device->parsed_name(); + TF_RET_CHECK(parsed_name.has_job); + job_names.insert(parsed_name.job); + } + if (job_names.size() > 1) { + return errors::InvalidArgument( + "System devices cannot be part " + "of multiple different jobs. Found: ", + str_util::Join(job_names, ",")); + } + + // Identify the lexicographically first device from the list of + // valid TPU SYSTEM devices, so that every process in the same + // 'cluster' definition uses the same system device. 
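+    // The comparator below orders candidates by replica and then by task, so
+    // the device taken at index 0 is the same in every process.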
+ std::sort(system_devices.begin(), system_devices.end(), + [](Device* i, Device* j) { + auto i_name = i->parsed_name(); + auto j_name = j->parsed_name(); + if (i_name.replica != j_name.replica) { + return i_name.replica < j_name.replica; + } + return i_name.task < j_name.task; + }); + } + + *system_device = system_devices[0]; + if (!DeviceNameUtils::ParseFullName((*system_device)->name(), system_spec)) { + return errors::InvalidArgument("Unable to re-parse system device name ", + (*system_device)->name(), + " as a device spec."); + } + return Status::OK(); +} +// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc) + +// LINT.IfChange +Status DistributedTPURewriteHelpers::GetHostSystemDevices( + const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set, + std::vector* host_system_devices) { + DeviceNameUtils::ParsedName host_spec; + if (system_spec.has_job) { + // The system Op has been explicitly assigned to a job, so we want + // all the hosts in that job. + CHECK(DeviceNameUtils::ParseFullName( + strings::StrCat("/job:", system_spec.job, "/device:", DEVICE_TPU_SYSTEM, + ":0"), + &host_spec)); + } else { + // The system Op has not been explicitly assigned to a + // job, so take all hosts in the system. There will be a runtime + // error if some of those hosts don't contain TPU devices. + CHECK(DeviceNameUtils::ParseFullName( + strings::StrCat("/device:", DEVICE_TPU_SYSTEM, ":0"), &host_spec)); + } + device_set.FindMatchingDevices(host_spec, host_system_devices); + + TF_RET_CHECK(!host_system_devices->empty()) + << "No hosts found matching device spec " + << DeviceNameUtils::ParsedNameToString(host_spec); + + // Check that all the devices belong to the same job. + TF_RET_CHECK((*host_system_devices)[0]->parsed_name().has_job); + const string& job_name = (*host_system_devices)[0]->parsed_name().job; + int replica = (*host_system_devices)[0]->parsed_name().replica; + for (const auto host_device : *host_system_devices) { + const auto& parsed_name = host_device->parsed_name(); + TF_RET_CHECK(parsed_name.has_job); + if (parsed_name.job != job_name) { + return errors::InvalidArgument( + "All TPU host devices must be in the same job"); + } + TF_RET_CHECK(parsed_name.has_replica); + if (parsed_name.replica != replica) { + return errors::InvalidArgument( + "All TPU host devices must be in the same replica"); + } + } + + // Sort the devices by replica and then task. + std::sort(host_system_devices->begin(), host_system_devices->end(), + [](Device* i, Device* j) { + auto i_name = i->parsed_name(); + auto j_name = j->parsed_name(); + return i_name.task < j_name.task; + }); + return Status::OK(); +} +// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc) + +// LINT.IfChange +Status DistributedTPURewriteHelpers::GetTPUDevices( + const DeviceNameUtils::ParsedName& system_spec, const DeviceSet& device_set, + int* num_tpus_per_host, std::vector>* tpu_devices) { + // GetHostSystemDevices returns the CPU device on each host that is + // going to be used for executing TPU code. + std::vector host_system_devices; + TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetHostSystemDevices( + system_spec, device_set, &host_system_devices)); + + // Enumerate all the physical devices. Enumerate devices on task 0, + // then task 1, etc. 
+ std::sort(host_system_devices.begin(), host_system_devices.end(), + [](Device* i, Device* j) { + return i->parsed_name().task < j->parsed_name().task; + }); + + *num_tpus_per_host = 0; + tpu_devices->clear(); + tpu_devices->reserve(host_system_devices.size()); + for (const auto device : host_system_devices) { + // Make a copy of the parsed name because we are going to change it. + DeviceNameUtils::ParsedName device_spec = device->parsed_name(); + device_spec.has_type = true; + device_spec.type = "TPU"; + // Enumerate all the available TPUs. + device_spec.has_id = false; + std::vector host_tpu_devices; + device_set.FindMatchingDevices(device_spec, &host_tpu_devices); + // Sort the devices by device id. + std::sort(host_tpu_devices.begin(), host_tpu_devices.end(), + [](Device* i, Device* j) { + return i->parsed_name().id < j->parsed_name().id; + }); + if (tpu_devices->empty()) { + // First iteration: set *num_tpus_per_host to the number of TPUs on the + // first host. + *num_tpus_per_host = host_tpu_devices.size(); + } else if (*num_tpus_per_host != host_tpu_devices.size()) { + // Subsequent iterations: check the number of TPUs match the number on + // the first host. + return errors::InvalidArgument( + "Mismatched number of TPU devices in cluster ", *num_tpus_per_host, + " vs. ", host_tpu_devices.size()); + } + tpu_devices->push_back(std::move(host_tpu_devices)); + } + return Status::OK(); +} +// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc) + +Status DistributedTPURewriteHelpers::ForConfigurationNodeMatchingType( + const string& node_type, Graph* graph, const DeviceSet& device_set, + const std::function< + Status(const NodeDef& configuration_node_def, + const string& configuration_device_name, + const std::vector& host_devices, + const std::vector& input_dependencies, + const std::vector& output_dependencies, + Graph* graph)>& action) { + // Find all the matching nodes before mutating the graph. + std::vector nodes; + for (Node* node : graph->nodes()) { + if (node->type_string() == node_type) { + nodes.push_back(node); + } + } + + for (Node* node : nodes) { + string spec_string = node->requested_device(); + DeviceNameUtils::ParsedName spec; + Device* device; + TF_RETURN_IF_ERROR( + GetSystemDevice(spec_string, device_set, &spec, &device)); + const string& device_name = device->name(); + + std::vector host_devices; + TF_RETURN_IF_ERROR(GetHostSystemDevices(spec, device_set, &host_devices)); + + std::vector input_dependencies; + for (const Edge* edge : node->in_edges()) { + // Config ops have no inputs, so all edges must be control edges. + CHECK(edge->IsControlEdge()); + input_dependencies.push_back(edge->src()); + } + std::vector output_dependencies; + for (const Edge* edge : node->out_edges()) { + OutputDependency dep; + dep.src_output = edge->src_output(); + dep.dst = edge->dst(); + dep.dst_input = edge->dst_input(); + output_dependencies.push_back(dep); + } + NodeDef node_def = node->def(); + + // Remove the node now so we can insert a new node with the same + // name inside the action. 
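+    // RemoveNode invalidates `node`, which is why its NodeDef and its input
+    // and output edges were copied above before the rewrite action runs.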
+ graph->RemoveNode(node); + + TF_RETURN_IF_ERROR(action(node_def, device_name, host_devices, + input_dependencies, output_dependencies, graph)); + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h new file mode 100644 index 00000000000..40aacceb5d5 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for TPU rewrite passes. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class DistributedTPURewriteHelpers { + public: + // Given a user-assigned device string, system_spec_string, parse it into + // system_spec. Verify that the device type is either TPU_SYSTEM or + // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the + // type, verify that the spec matches a unique device in device_set, and + // return that device in system_device. The normal use case is for + // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the + // job that contains the TPU hardware. + // TODO(b/110910013): Possibly remove the tpu system device. + static Status GetSystemDevice(const string& system_spec_string, + const DeviceSet& device_set, + DeviceNameUtils::ParsedName* system_spec, + Device** system_device); + + // Given a parsed system spec (e.g., the one returned above from + // GetSystemDeviceName), return in host_devices the TPU_SYSTEM:0 device on + // every host in the spec's job. If the spec does not include an explicit job, + // "localhost" is used. Returns an error if system_spec matches devices from + // a multiple jobs or replicas. + static Status GetHostSystemDevices( + const DeviceNameUtils::ParsedName& system_spec, + const DeviceSet& device_set, std::vector* host_system_devices); + + // Given a parsed system spec (e.g., the one returned above from + // GetSystemDeviceName), sets `*tpu_devices` to a per-host vector of the TPU + // devices on every host in the spec's job. If the spec does not include an + // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number + // of TPU devices in each host, and verifies that each host in the job has + // the same number of TPU devices. + // Returns an error if system_spec matches devices from a multiple jobs or + // replicas. 
+ static Status GetTPUDevices(const DeviceNameUtils::ParsedName& system_spec, + const DeviceSet& device_set, + int* num_tpus_per_host, + std::vector>* tpu_devices); + + // Perform 'action' on every node in 'graph' of type + // 'node_type'. This function is designed for use with configuration + // Ops that have no inputs or outputs. The arguments passed to 'action' are: + // 'configuration_node_name': the name of the node that matched + // 'configuration_device_name': the name of the device that the + // matching node is placed on + // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are + // in the same system as the node that matched. + // 'input_dependencies': the set of nodes that have control edges to + // the matching node. + // 'output_dependencies': the set of output port, destination node, input port + // triples that have edges from the matching node. Input port is + // Graph::kControlSlot for a control edge. + // 'graph': the graph being mutated. + struct OutputDependency { + int src_output; + Node* dst; + int dst_input; + }; + static Status ForConfigurationNodeMatchingType( + const string& node_type, Graph* graph, const DeviceSet& device_set, + const std::function< + Status(const NodeDef& configuration_node_def, + const string& configuration_device_name, + const std::vector& host_devices, + const std::vector& input_dependencies, + const std::vector& output_dependencies, + Graph* graph)>& action); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD new file mode 100644 index 00000000000..0e5a91c961c --- /dev/null +++ b/tensorflow/core/tpu/kernels/BUILD @@ -0,0 +1,288 @@ +# TPU Kernel Implementations +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_proto_library_cc", +) + +package( + default_visibility = [ + "//tensorflow/core/tpu:__subpackages__", + "//tensorflow/stream_executor/tpu:__subpackages__", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tpu_compile_op_options", + srcs = ["tpu_compile_op_options.cc"], + hdrs = ["tpu_compile_op_options.h"], +) + +cc_library( + name = "tpu_configuration_ops", + srcs = ["tpu_configuration_ops.cc"], + hdrs = ["tpu_configuration_ops.h"], + deps = [ + ":tpu_mesh_state_interface", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/tpu:tpu_config_c_api", + "//tensorflow/core/tpu:tpu_configuration", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/tpu:tpu_library_loader", + "//tensorflow/stream_executor/tpu:proto_helper", + ], + alwayslink = 1, +) + +cc_library( + name = "tpu_compile_c_api_hdrs", + hdrs = ["tpu_compile_c_api.h"], + deps = [ + ":tpu_mesh_state_c_api", + "//tensorflow/c:tf_datatype", + "//tensorflow/stream_executor/tpu:proto_helper", + "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", + ], +) + +tf_proto_library_cc( + name = "tpu_executable_info_proto", + srcs = ["tpu_executable_info.proto"], + cc_api_version = 2, + protodeps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:protos_all", + ], +) + +tf_proto_library_cc( + name = "tpu_compile_proto", + srcs = ["tpu_compile.proto"], + cc_api_version = 2, + protodeps = [ + ":tpu_executable_info_proto", + 
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:protos_all", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto", + ], +) + +cc_library( + name = "tpu_compilation_cache_key", + srcs = [], + hdrs = [ + "tpu_compilation_cache_key.h", + ], + deps = ["@com_google_absl//absl/types:optional"], +) + +cc_library( + name = "tpu_compile_op_support", + srcs = ["tpu_compile_op_support.cc"], + hdrs = ["tpu_compile_op_support.h"], + deps = [ + ":tpu_compilation_cache_key", + ":tpu_compile_c_api_hdrs", + ":tpu_compile_proto_cc", + ":tpu_executable_info_proto_cc", + "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:dump", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_module_group", + "//tensorflow/core:framework", + "//tensorflow/core/framework:protos_all_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/stream_executor/tpu:proto_helper", + "//tensorflow/stream_executor/tpu:status_helper", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tpu_compilation_cache_entry", + hdrs = [ + "tpu_compilation_cache_entry.h", + ], + deps = [ + ":tpu_executable_info_proto_cc", + ":tpu_program", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core/lib/core:refcount", + ], +) + +cc_library( + name = "tpu_compilation_cache_lookup", + srcs = ["tpu_compilation_cache_lookup.cc"], + hdrs = [ + "tpu_compilation_cache_lookup.h", + ], + deps = [ + ":tpu_compilation_cache_entry", + ":tpu_compilation_cache_external", + ":tpu_compilation_cache_proto_cc", + "//tensorflow/core/lib/core:refcount", + "//tensorflow/core/platform:status", + "//tensorflow/core/profiler/lib:traceme", + ], +) + +cc_library( + name = "tpu_mesh_state_c_api", + hdrs = ["tpu_mesh_state_c_api.h"], +) + +cc_library( + name = "tpu_mesh_state_interface", + srcs = [], + hdrs = ["tpu_mesh_state_interface.h"], + deps = [ + ":tpu_compile_c_api_hdrs", + ":tpu_mesh_state_c_api", + "//tensorflow/compiler/xla/service", + "//tensorflow/core:framework", + "//tensorflow/core/platform:errors", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu:tpu_config_c_api", + ], +) + +cc_library( + name = "tpu_program", + srcs = ["tpu_program.cc"], + hdrs = ["tpu_program.h"], + deps = [ + ":tpu_compile_c_api_hdrs", + ":tpu_compile_op_support", + ":tpu_compile_proto_cc", + ":tpu_executable_info_proto_cc", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:hlo_module_group", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/stream_executor/tpu:proto_helper", + 
"//tensorflow/stream_executor/tpu:status_helper", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "tpu_compilation_cache_external", + srcs = ["tpu_compilation_cache_external.cc"], + hdrs = [ + "tpu_compilation_cache_external.h", + ], + deps = [ + ":tpu_compilation_cache_entry", + ":tpu_compilation_cache_key", + ":tpu_compilation_cache_metrics", # buildcleaner: keep + ":tpu_compilation_cache_metrics_hdrs", + ":tpu_compilation_cache_proto_cc", + ":tpu_compile_c_api_hdrs", + ":tpu_compile_op_support", + ":tpu_mesh_state_interface", + ":tpu_program", + ":tpu_util", + ":trace_util_hdrs", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tpu_compilation_cache_metrics_hdrs", + hdrs = ["tpu_compilation_cache_metrics.h"], + deps = [ + "//tensorflow/core/platform:types", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tpu_compilation_cache_metrics", + srcs = ["tpu_compilation_cache_metrics.cc"], + deps = [ + ":tpu_compilation_cache_metrics_hdrs", + ], +) + +cc_library( + name = "trace_util_hdrs", + srcs = [], + hdrs = ["trace_util.h"], + deps = [ + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tpu_util_hdrs", + srcs = [], + hdrs = ["tpu_util.h"], + deps = [ + ":tpu_compilation_cache_key", + "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "tpu_util", + srcs = ["tpu_util.cc"], + hdrs = ["tpu_util.h"], + deps = [ + ":tpu_compilation_cache_key", + "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:compile_only_client", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +tf_proto_library_cc( + name = "tpu_compilation_cache_proto", + srcs = ["tpu_compilation_cache.proto"], + cc_api_version = 2, +) diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto new file mode 100644 index 00000000000..8308cba128e --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +syntax = "proto3"; + +package tensorflow.tpu; + +// Target type for compilation cache fetch operation. +enum CompilationCacheFetchTarget { + INVALID = 0; + MAIN = 1; + SHARDING = 2; + UNSHARDING = 3; +} diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h new file mode 100644 index 00000000000..d16b2d521f6 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_ +#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_ + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_program.h" + +namespace tensorflow { +namespace tpu { + +class CompilationCacheEntry { + public: + explicit CompilationCacheEntry( + std::unique_ptr tpu_program) + : tpu_program_(std::move(tpu_program)) {} + + // Constructor for an empty entry. + CompilationCacheEntry() + : tpu_program_(nullptr) {} + + const TPUExecutableInfoProto* get_executable_info() const { + return &tpu_program_->executable_info(); + } + + const TPUHostTransferInfoProto* get_host_transfer_info() const { + return &tpu_program_->host_transfer_info(); + } + + const xla::HloProto* get_hlo_metadata() const { + return &tpu_program_->hlo_metadata(); + } + + // TODO(henrytan,jiawenhao): When should we expect more than one + // XLA_TpuProgram* per TpuProgram? Remove the program_count CHECK below then. + const XLA_TpuProgram* get_tpu_program() const { + CHECK_EQ(tpu_program_->program_count(), 1); + return tpu_program_->tpu_programs()[0]; + } + + private: + std::unique_ptr tpu_program_; +}; + +// Base class for a reference to a cached proto. A unique_ptr to a +// CompilationCacheEntryRef is returned by all the cache Lookup methods below, +// and ensures the underlying proto is not garbage-collected until the client +// discards the ptr. +class CompilationCacheEntryRef { + public: + virtual ~CompilationCacheEntryRef() = default; + + // Returns a CompilationCacheEntry that should not be used beyond the lifetime + // of the CompilationCacheEntryRef. + virtual CompilationCacheEntry get() = 0; +}; + +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. 
+class CompilationRefHolder : public ResourceBase { + public: + ~CompilationRefHolder() override = default; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_ENTRY_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc new file mode 100644 index 00000000000..8dbf60803cc --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc @@ -0,0 +1,791 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/random.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_program.h" +#include "tensorflow/core/tpu/kernels/tpu_util.h" +#include "tensorflow/core/tpu/kernels/trace_util.h" + +namespace tensorflow { +namespace tpu { + +namespace { + +using CompilationEntry = TpuCompilationCacheInterface::CompilationEntry; + +int64 get_uid() { + uint64 unsigned_rand = random::New64() & INT64_MAX; + return static_cast(unsigned_rand); +} + +void PopulateEntry(const std::string& key, CompilationEntry* entry, + std::unique_ptr tpu_program) { + // Make the unique keys for each cached proto. + for (int i = 0; i < tpu_program->program_count(); ++i) { + entry->proto_key.push_back(ProtoKeyForComputation(key, i)); + } + + entry->tpu_program = std::move(tpu_program); + entry->initialized = true; +} + +std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) { + if (!key.has_guaranteed_const) { + return key.prefix; + } + return absl::StrCat(key.prefix, "|", key.session_handle, "|", + key.guaranteed_const_fingerprint()); +} + +// Return fingerprint_in_metadata if it's not empty; otherwise read input tensor +// data to compute the fingerprint. 
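+// The computed fingerprint folds the raw bytes of every guaranteed constant
+// into a single value, so a change to any constant yields a different
+// compilation cache key.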
+std::string GuaranteedConstFingerprint( + const string& fingerprint_in_metadata, + const OpInputList& guaranteed_constants) { + if (fingerprint_in_metadata.empty()) { + uint64_t fingerprint = 0; + for (const auto& constant : guaranteed_constants) { + fingerprint = TpuCompile_CreateGuaranteedConstFingerprint( + fingerprint, constant.tensor_data().data(), + constant.tensor_data().size()); + } + return std::to_string(fingerprint); + } else { + return fingerprint_in_metadata; + } +} + +std::string CreateShapePrefix( + const std::vector& dynamic_shapes) { + std::string shapes_prefix; + for (const TensorShape& shape : dynamic_shapes) { + for (int64 size : shape.dim_sizes()) { + absl::StrAppend(&shapes_prefix, size, ","); + } + absl::StrAppend(&shapes_prefix, ";"); + } + return shapes_prefix; +} + +// Include compilation configurations of the arguments that are not captured +// by the called graph. +std::string CreateConfigPrefix(const TPUCompileMetadataProto& metadata) { + std::string config_prefix; + for (const auto& arg : metadata.args()) { + if (arg.is_same_data_across_replicas()) { + absl::StrAppend(&config_prefix, ":s"); + // Same. + } else { + // Different. + absl::StrAppend(&config_prefix, ":"); + } + if (arg.enable_xla_sharding() == + tpu::TPUCompileMetadataProto::Arg::ALLOWED) { + // Enabled. + absl::StrAppend(&config_prefix, "e"); + } + if (arg.unrestricted_layout()) { + // Unrestricted. + absl::StrAppend(&config_prefix, ":u"); + } + absl::StrAppend(&config_prefix, ",type(", arg.dtype(), ")"); + if (arg.has_shape()) { + absl::StrAppend(&config_prefix, ",shape("); + for (const auto& dim : arg.shape().dim()) { + absl::StrAppend(&config_prefix, dim.size(), ","); + } + absl::StrAppend(&config_prefix, ")"); + } + } + return config_prefix; +} + +} // namespace + +TpuCompilationCacheInterface::TpuCompilationCacheInterface( + int64_t max_cache_size) + : max_cache_size_(max_cache_size) { + if (max_cache_size < 0) { + LOG(FATAL) << "`max_cache_size` value must be greater than equal to 0"; + } + VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes."; +} + +TpuCompilationCacheInterface::~TpuCompilationCacheInterface() { + VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()"; + // A buggy client may be holding onto a reference, or a client might have + // crashed while holding onto a reference. In either case, discard all + // outstanding client references to avoid leaking storage. + for (const auto& entry : entries_by_uid_) { + while (entry.second->external_references > 0) { + TF_CHECK_OK(Release(entry.first)); + } + } + while (!entries_by_last_use_.empty()) { + UnloadAndDestroy(MarkOldestEntryForEviction()); + } + // By the time the cache is deleted all reference holders should have already + // been deleted, since they were holding references to the cache. So all + // entries should be gone at this point. 
+ CHECK_EQ(cache_store_.size(), 0); + CHECK_EQ(entries_by_uid_.size(), 0); + CHECK_EQ(entries_by_proto_key_.size(), 0); + CHECK_EQ(cache_size_, 0); + CHECK_EQ(marked_for_eviction_size_, 0); +} + +std::string TpuCompilationCacheInterface::FindCacheKey( + const TpuCompilationCacheKey& subgraph_key) const { + if (!subgraph_key.has_guaranteed_const) { + return subgraph_key.prefix; + } + auto iter = session_key_map_.find( + strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle)); + if (iter != session_key_map_.end()) { + return iter->second; + } + iter = fingerprint_key_map_.find(strings::StrCat( + subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint())); + if (iter != session_key_map_.end()) { + return iter->second; + } + VLOG(1) << "No matching cache key found for key " + << ConstructCompilationCacheKey(subgraph_key); + return ""; +} + +void TpuCompilationCacheInterface::InsertEntry( + const std::string& cache_key, const TpuCompilationCacheKey& subgraph_key, + CompilationEntry* entry) { + entry->parent = this; + entry->subgraph_key = cache_key; + entry->uid = get_uid(); + TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); + entry->cache_entry_debug_string = subgraph_key.prefix; + VLOG(1) << "Cache Initializing Entry Session Debug " + << entry->cache_entry_debug_string; + + if (!subgraph_key.has_guaranteed_const) { + return; + } + session_key_map_.insert(std::make_pair( + strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle), + cache_key)); + fingerprint_key_map_.insert(std::make_pair( + strings::StrCat(subgraph_key.prefix, + subgraph_key.guaranteed_const_fingerprint()), + cache_key)); +} + +CompilationEntry* TpuCompilationCacheInterface::InitializeEntry( + const string& key, + const std::function& initialize_program, + const TpuCompilationCacheKey& subgraph_key) { + CompilationEntry* main_entry = new CompilationEntry(); + + // Add the entry to the cache, with size zero since there are no compiled + // programs in it. Once the subgraph has been compiled, + // UpdateEntryAfterCompilation will be called to potentially mark old entries + // that don't fit any more for eviction. + // + // At this point there is one reference to entry, which is owned by the caller + // who created the entry. A second reference, owned by the cache, will be + // added below since we leave the entry in the 'marked for eviction' state + // here. + InsertEntry(key, subgraph_key, main_entry); + + // Initialize the programs outside the lock so that other cache operations + // can proceed during the (potentially lengthy) initialization. + Status initialization_status; + + auto tpu_program = absl::make_unique(); + { + mu_.Unlock(); + { + profiler::TraceMe compile_programs_traceme( + "TPU compilation cache compile", + /*level=*/2); + initialization_status = initialize_program(tpu_program.get()); + } + mu_.Lock(); + } + + main_entry->initialization_status = initialization_status; + + // Add the entry to the uid index. + auto uid_inserted = entries_by_uid_.insert( + std::pair(main_entry->uid, main_entry)); + CHECK(uid_inserted.second); + + if (initialization_status.ok()) { + // Compute the entries total size once all members are initialized. + main_entry->total_size = tpu_program->program_size(); + } + + // TODO(henrytan): handle sharding/unsharding. 
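+  // PopulateEntry fills in one proto key per program held by the TpuProgram;
+  // the loop that follows indexes each key in entries_by_proto_key_ so that
+  // lookups can address individual programs.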
+ PopulateEntry(key, main_entry, std::move(tpu_program)); + + for (int64 i = 0; i < main_entry->proto_key.size(); ++i) { + auto entry_inserted = entries_by_proto_key_.insert( + std::pair>( + main_entry->proto_key[i], std::make_pair(main_entry, i))); + CHECK(entry_inserted.second); + } + + // Add the size to marked_for_eviction_size_ since it will be adjusted down + // again when the newly-created entry gets unmarked. + marked_for_eviction_size_ += main_entry->total_size; + return main_entry; +} + +/*static*/ TpuCompilationCacheKey +TpuCompilationCacheInterface::CreateCompilationCacheKey( + absl::string_view function_name, uint64 function_library_fingerprint, + absl::string_view mlir_module, + const tensorflow::OpInputList& guaranteed_constants, + const std::vector& dynamic_shapes, + const tensorflow::tpu::TPUCompileMetadataProto& metadata, + const TpuMeshStateInterface& mesh_state) { + VLOG(1) << "FunctionLibraryFingerprint:" << function_library_fingerprint; + std::string shapes_prefix = CreateShapePrefix(dynamic_shapes); + VLOG(1) << "shapes_prefix = " << shapes_prefix; + std::string config_prefix = CreateConfigPrefix(metadata); + VLOG(1) << "config_prefix = " << config_prefix; + std::vector flattened_device_ids; + if (metadata.has_device_assignment()) { + for (const auto& device : + metadata.device_assignment().computation_devices()) { + flattened_device_ids.insert(flattened_device_ids.end(), + device.replica_device_ids().begin(), + device.replica_device_ids().end()); + } + } + // TODO(henrytan): return the debug_string. + const char* prefix = + TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty{ + config_prefix.data(), + shapes_prefix.data(), + function_name.data(), + mlir_module.data(), + flattened_device_ids.data(), + flattened_device_ids.size(), + guaranteed_constants.size(), + function_library_fingerprint, + metadata.num_cores_per_replica(), + metadata.num_replicas(), + mesh_state.data(), + }); + auto buffer_cleanup = gtl::MakeCleanup([prefix]() { delete[] prefix; }); + TpuCompilationCacheKey key; + key.prefix = prefix; + + // Guaranteed constants can be different across sessions. Use session_handle + // and guaranteed_const fingerprint to guarantee no collision. + if (guaranteed_constants.size() > 0) { + key.has_guaranteed_const = true; + key.session_handle = metadata.session_handle(); + // Both `metadata` and `guaranteed_constants` lifetime are captured by + // reference based on the assumption that these variables lifetime is + // managed through the `TPUCompileOpKernelImpl` that outlives the + // lifetime of the compilation cache lookups. + string fingerprint; + key.guaranteed_const_fingerprint = [&metadata, &guaranteed_constants, + fingerprint]() mutable { + if (fingerprint.empty()) { + fingerprint = GuaranteedConstFingerprint( + metadata.guaranteed_const_fingerprint(), guaranteed_constants); + } + return fingerprint; + }; + } + return key; +} + +TpuCompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() { + return new RefHolder(this); +} + +Status TpuCompilationCacheInterface::MarkEntryForEviction(int64 subgraph_uid) { + profiler::TraceMe key_release_traceme( + "TPU compilation cache possibly evict uid", + /*level=*/2); + CompilationEntry* deleted_entry = nullptr; + { + absl::MutexLock lock(&mu_); + auto iter = entries_by_uid_.find(subgraph_uid); + if (iter == entries_by_uid_.end()) { + // If already evicted, return ok. + return Status::OK(); + } + + // Mark entry for eviction. 
+ CompilationEntry* subgraph_to_evict = iter->second; + // If there are external references, should not use this API. + if (subgraph_to_evict->external_references != 0) { + return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key, + " external_references greater than zero. Should " + "use TpuCompilationCache::Release."); + } + + VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key << " for eviction"; + entries_by_last_use_.erase(subgraph_to_evict->last_use); + cache_size_ -= subgraph_to_evict->total_size; + marked_for_eviction_size_ += subgraph_to_evict->total_size; + + // Evict if refcount exactly one, otherwise only discard cache's reference + // to the entry while the actual eviction will happen when refholder's + // references go away. + deleted_entry = DiscardEntryRef(subgraph_to_evict); + + VLOG(1) << "After possibly evicting entry " << subgraph_uid + << " refs cache is " << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ + << " bytes), marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + + // Unload from device cache if entry is evicted from host cache. + UnloadAndDestroy(deleted_entry); + return Status::OK(); +} + +Status TpuCompilationCacheInterface::Release(int64 subgraph_uid) { + profiler::TraceMe key_release_traceme("TPU compilation cache release uid", + /*level=*/2); + + CompilationEntry* deleted_entry = nullptr; + { + absl::MutexLock lock(&mu_); + auto iter = entries_by_uid_.find(subgraph_uid); + + if (iter == entries_by_uid_.end()) { + return errors::NotFound("No cache entry found for uid ", subgraph_uid); + } + + CHECK_GT(iter->second->external_references, 0); + --iter->second->external_references; + + deleted_entry = DiscardEntryRef(iter->second); + + VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is " + << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ + << " bytes), marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + UnloadAndDestroy(deleted_entry); + return Status::OK(); +} + +void TpuCompilationCacheInterface::UnloadAndDestroy(CompilationEntry* entry) { + if (!entry) return; + + CHECK(entry->RefCountIsOne()); + entry->tpu_program->UnloadAndDestroyPrograms(); + entry->Unref(); +} + +size_t TpuCompilationCacheInterface::RemoveEntry(const string& key) { + auto erased = cache_store_.erase(key); + TpuCompilationCacheMetrics::SetCacheEntryCount(cache_store_.size()); + auto parsed_key_or_status = ParseCompilationCacheKey(key); + CHECK(parsed_key_or_status.status().ok()); + const TpuCompilationCacheKey parsed_key = + parsed_key_or_status.ConsumeValueOrDie(); + if (!parsed_key.has_guaranteed_const) { + return erased; + } + session_key_map_.erase( + strings::StrCat(parsed_key.prefix, parsed_key.session_handle)); + fingerprint_key_map_.erase(strings::StrCat( + parsed_key.prefix, parsed_key.guaranteed_const_fingerprint())); + return erased; +} + +ABSL_MUST_USE_RESULT CompilationEntry* +TpuCompilationCacheInterface::DiscardEntryRef(CompilationEntry* entry) { + if (entry->RefCountIsOne()) { + // The last reference to this entry is going away, so really delete it from + // the cache in such a way that it can't be restored by being looked up + // again. + + // Sanity-check that it has been marked for eviction. 
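+    // Being marked for eviction means the entry has already been removed from
+    // entries_by_last_use_ (by MarkEntryForEviction or
+    // MarkOldestEntryForEviction), so finding its last_use key there now would
+    // indicate a reference-counting bug.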
+    CHECK(entries_by_last_use_.find(entry->last_use) ==
+          entries_by_last_use_.end());
+    // Update the counter tracking how much space is taken up by entries that
+    // are marked for eviction.
+    marked_for_eviction_size_ -= entry->total_size;
+
+    // Remove the entry from the cache.
+    auto erased = RemoveEntry(entry->subgraph_key);
+
+    if (erased == 0) {
+      LOG(FATAL) << "Tried to discard nonexistent cache entry";
+    }
+    erased = entries_by_uid_.erase(entry->uid);
+    CHECK_EQ(erased, 1);
+    for (const string& key : entry->proto_key) {
+      erased = entries_by_proto_key_.erase(key);
+      CHECK_EQ(erased, 1);
+    }
+    // The actual deletion will happen outside the lock in UnloadAndDestroy().
+    return entry;
+  }
+  entry->Unref();
+  return nullptr;
+}
+
+void TpuCompilationCacheInterface::DiscardEntryRefs(
+    gtl::ArraySlice<CompilationEntry*> entries) {
+  std::vector<CompilationEntry*> removed_entries;
+  {
+    absl::MutexLock lock(&mu_);
+
+    for (auto entry : entries) {
+      removed_entries.push_back(DiscardEntryRef(entry));
+    }
+
+    VLOG(1) << "After discarding entry refs cache is " << cache_store_.size()
+            << " entries (" << cache_size_ + marked_for_eviction_size_
+            << " bytes), marked for eviction "
+            << (cache_store_.size() - entries_by_last_use_.size())
+            << " entries (" << marked_for_eviction_size_ << " bytes).";
+  }
+  for (auto removed_entry : removed_entries) {
+    UnloadAndDestroy(removed_entry);
+  }
+}
+
+ABSL_MUST_USE_RESULT CompilationEntry*
+TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
+  CompilationEntry* entry_to_mark = entries_by_last_use_.begin()->second;
+  VLOG(1) << "Marking " << entry_to_mark->subgraph_key << " for eviction";
+  entries_by_last_use_.erase(entry_to_mark->last_use);
+  cache_size_ -= entry_to_mark->total_size;
+  marked_for_eviction_size_ += entry_to_mark->total_size;
+  // Discard the cache's reference to entry. If steps are holding onto
+  // references to entry it won't be deleted until the last step holding it
+  // completes. It stays in the cache in the meantime and can be resurrected
+  // by a call to CompileIfKeyAbsent if that occurs before the last reference
+  // expires.
+  return DiscardEntryRef(entry_to_mark);
+}
+
+void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
+    CompilationEntry* entry,
+    std::vector<CompilationEntry*>* removed_entries) {
+  // The entry was previously marked for eviction (or is newly created) so
+  // unmark it. Add a reference (owned by the cache), update the cache size,
+  // and mark something old for eviction if necessary.
+  entry->Ref();
+  marked_for_eviction_size_ -= entry->total_size;
+  cache_size_ += entry->total_size;
+
+  // Mark the least-recently-used non-marked entry for eviction. Never mark the
+  // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size()
+  // == 1, which means there's only one entry not already marked for eviction),
+  // so that an entry persists in the cache even if it is larger than the
+  // allocated cache size.
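+  // As a concrete illustration (hypothetical sizes): with max_cache_size_ =
+  // 100 and unmarked entries of 60, 30 and 40 bytes (cache_size_ = 130), the
+  // loop below marks the least recently used entry, say the 60-byte one,
+  // which brings cache_size_ down to 70 and then stops; the most recently
+  // used entry is never marked, even if it alone exceeds max_cache_size_.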
+ while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) { + if (auto entry_to_evict = MarkOldestEntryForEviction()) { + removed_entries->push_back(entry_to_evict); + } + } +} + +Status TpuCompilationCacheInterface::ToSubEntryRef( + CompilationCacheEntryRef* entry, + CompilationCacheFetchTarget fetch_target) const { + return static_cast(entry)->ToSubEntryRef(fetch_target); +} + +TpuCompilationCacheInterface::EntryRefImpl::EntryRefImpl( + TpuCompilationCacheInterface* parent, CompilationEntry* entry, int index) + : parent_(parent), entry_(entry), index_(index) { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + entry_->Ref(); + } else { + // This is a sharding/unsharding entry nested in a main entry. Only refcount + // the main entry. + entry_->main_entry->Ref(); + } +} + +TpuCompilationCacheInterface::EntryRefImpl::~EntryRefImpl() { + if (entry_ == nullptr) { + return; + } + if (entry_->main_entry == nullptr) { + parent_->DiscardEntryRefs({entry_}); + } else { + parent_->DiscardEntryRefs({entry_->main_entry}); + } +} + +CompilationCacheEntry TpuCompilationCacheInterface::EntryRefImpl::get() { + if (entry_ == nullptr) { + // Create an empty entry if the entry is nullptr. This corresponds to + // non-existing sharding/unsharding entries. + return CompilationCacheEntry(); + } + return CompilationCacheEntry(std::move(entry_->tpu_program)); +} + +Status TpuCompilationCacheInterface::EntryRefImpl::ToSubEntryRef( + CompilationCacheFetchTarget fetch_target) { + CompilationEntry* target = nullptr; + switch (fetch_target) { + case CompilationCacheFetchTarget::MAIN: + target = entry_; + break; + case CompilationCacheFetchTarget::SHARDING: + target = entry_->sharding_entry.get(); + break; + case CompilationCacheFetchTarget::UNSHARDING: + target = entry_->unsharding_entry.get(); + break; + default: + return xla::InvalidArgument("Invalid fetch target: %d", fetch_target); + } + + if (target == nullptr) { + // Cache entry does not have an unsharding subentry. Unref and replace + // with nullptr. + parent_->DiscardEntryRefs({entry_}); + } + // Otherwise, since the refcount is always on the main entry, we don't need + // ref/unref. 
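+  // When the requested subentry is absent, target stays nullptr here, so the
+  // assignment below makes this ref point at a null entry; a subsequent get()
+  // then returns an empty CompilationCacheEntry (see EntryRefImpl::get above),
+  // which callers observe as a missing executable.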
+ entry_ = target; + return Status::OK(); +} + +Status TpuCompilationCacheInterface::Lookup( + int64 uid, int proto_index, + std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme( + "TPU compilation cache proto lookup by uid", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_uid_.find(uid); + if (iter == entries_by_uid_.end()) { + return errors::NotFound("No subgraph found for uid ", uid); + } + CompilationEntry* cache_entry = iter->second; + if (proto_index < 0 || + proto_index >= cache_entry->tpu_program->program_size()) { + return errors::NotFound("No proto found for core index ", proto_index, + " in subgraph with uid ", uid); + } + *entry = std::unique_ptr( + new EntryRefImpl(this, cache_entry, proto_index)); + return Status::OK(); +} + +Status TpuCompilationCacheInterface::Lookup( + const string& proto_key, std::unique_ptr* entry) { + entry->reset(); + + profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup", + /*level=*/2); + + absl::MutexLock lock(&mu_); + const auto iter = entries_by_proto_key_.find(proto_key); + if (iter == entries_by_proto_key_.end()) { + return errors::NotFound("No proto found for key ", proto_key); + } + CompilationEntry* cache_entry = iter->second.first; + int proto_index = iter->second.second; + *entry = std::unique_ptr( + new EntryRefImpl(this, cache_entry, proto_index)); + return Status::OK(); +} + +Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector* removed_entries, + std::vector>* hlo_metadata, + const std::function& compile_function) { + profiler::TraceMe subgraph_lookup_traceme( + "TPU compilation cache subgraph lookup", + /*level=*/2); + + // NOTE: In spite of the fact that we use MutexLock, we do not hold the lock + // for the lifetime of the object, see InitializeEntry() call below. + absl::MutexLock lock(&mu_); + + std::string cache_key = FindCacheKey(subgraph_key); + auto iter = cache_store_.find(cache_key); + bool is_new_key = iter == cache_store_.end(); + + const std::string session_name = SessionNameFromMetadata(session_metadata); + + CompilationEntry* entry = nullptr; + if (is_new_key) { + cache_key = ConstructCompilationCacheKey(subgraph_key); + TpuCompilationCacheMetrics::IncrementCacheLookupCount( + /*is_cache_hit=*/false, session_name); + const string msg = + strings::StrCat("TPU host compilation cache miss: cache_key(", + cache_key, "), session_name(", session_name, ")"); + + TRACESTRING(msg); + LOG(INFO) << msg; + + // Check if caller has disabled compilation. Set using + // internal::ScopedTpuCompileDisabler. + if (!IsTpuCompilationEnabled()) { + const string error_msg = strings::StrCat( + "[TpuCompilationDisabled]: Compilation cache miss, but compilation " + "disabled, session_name(", + session_name, ") Debug String: ", subgraph_key.debug_string); + if (VLOG_IS_ON(2)) { + VLOG(2) << "Cache Missed. Current cache entries: "; + for (auto it = cache_store_.begin(); it != cache_store_.end(); ++it) { + // TODO(henrytan): add DebugKey as cache_entry_debug_string to + // TpuCompilationCacheKey. 
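+          // cache_entry_debug_string currently carries only the key prefix
+          // (set in InsertEntry), which is why the TODO above asks for a
+          // richer DebugKey on TpuCompilationCacheKey.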
+ VLOG(2) << "Cache Debug Info: "; + VLOG(2) << it->second->cache_entry_debug_string; + } + } + + LOG_EVERY_N_SEC(WARNING, 30) << error_msg; + return errors::NotFound(error_msg); + } + + // The single ref on the newly-created entry is owned by the caller. + VLOG(1) << "Before adding new entry for key " << cache_key + << " with session_name( " << session_name << ");" + << "; cache is " << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + // Note that InitializeEntry() will Release/Reacquire mu_. + entry = InitializeEntry(cache_key, compile_function, subgraph_key); + TRACELITERAL("TPU host compilation cache: compilation done."); + + LOG(INFO) << strings::StrCat( + "TPU host compilation cache: compilation done for cache_key(", + cache_key, "), session_name(", session_name, ")"); + // If session_name is present, log some additional stats related to HBM + // here, so that they can be associated directly to the session. + if (!session_name.empty()) { + entry->tpu_program->LogProgramMemorySummary(); + } + } else { + TpuCompilationCacheMetrics::IncrementCacheLookupCount(true, session_name); + const string msg = + strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key, + "), session_name(", session_name, ")"); + TRACESTRING(msg); + VLOG(1) << msg; + VLOG(1) << "Before refreshing entry for key " << cache_key + << " with session_name( " << session_name << "); cache is " + << cache_store_.size() << " entries (" + << cache_size_ + marked_for_eviction_size_ << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + entry = iter->second; + // Make a new reference that is owned by the caller. + entry->Ref(); + // Block if necessary until the subgraph has been initialized. + mu_.Await(absl::Condition( + +[](CompilationEntry* e) { return e->initialized; }, entry)); + } + + // Let the caller know the uid of the entry. + *uid = entry->uid; + // Let the caller know the keys for each of the cached protos. + *proto_key = entry->proto_key; + *may_modify_variables = entry->tpu_program->may_modify_variables(); + *hlo_metadata = entry->hlo_metadata; + + // If the caller didn't supply a per_step_ref_holder then the caller is going + // to manually release the reference later via a call to Release(). + if (per_step_ref_holder == nullptr) { + ++entry->external_references; + } else { + // The caller wants its reference to be handed off to a per-step holder that + // will discard the reference when the step completes. + RefHolder* cast_ref_holder = static_cast(per_step_ref_holder); + TF_RET_CHECK(cast_ref_holder != nullptr); + cast_ref_holder->AddRef(entry); + } + + // Remove the old LRU-table entry if it wasn't already marked for eviction. + auto erased = entries_by_last_use_.erase(entry->last_use); + // Update the LRU table indicating this entry is the most recently used. + entry->last_use = use_counter_++; + entries_by_last_use_[entry->last_use] = entry; + if (erased == 0) { + // The entry had been marked for eviction, or is newly created. + LookupEntryMarkedForEviction(entry, removed_entries); + } + + // Log a little more verbosely when a key is added. + if (VLOG_IS_ON(1) || is_new_key) { + LOG(INFO) << "After " << (is_new_key ? 
"adding" : "refreshing") + << " entry for key " << cache_key << " with session_name " + << session_name << " cache is " << cache_store_.size() + << " entries (" << cache_size_ + marked_for_eviction_size_ + << " bytes), " + << " marked for eviction " + << (cache_store_.size() - entries_by_last_use_.size()) + << " entries (" << marked_for_eviction_size_ << " bytes)."; + } + return entry->initialization_status; +} + +tensorflow::Status TpuCompilationCacheInterface::CompileIfKeyAbsent( + const TpuCompilationCacheKey& cache_key, + const tensorflow::SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector>* hlo_metadata, + const std::function& compile_function) { + std::vector removed_entries; + auto status = CompileIfKeyAbsentHelper( + cache_key, session_metadata, per_step_ref_holder, uid, proto_key, + may_modify_variables, &removed_entries, hlo_metadata, compile_function); + for (auto entry : removed_entries) { + UnloadAndDestroy(entry); + } + return status; +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h new file mode 100644 index 00000000000..b6cdbe9fa0b --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -0,0 +1,394 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_ +#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_ + +#include +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_program.h" + +namespace tensorflow { +namespace tpu { + +const char kCompilationCacheResourceName[] = "tpu_compilation_cache"; +const char kCompilationCacheUnloaderResourceName[] = + "tpu_compilation_cache_unloader"; + +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. 
Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. +class TpuCompilationRefHolder : public ResourceBase { + public: + ~TpuCompilationRefHolder() override = default; +}; + +class TpuCompilationCacheInterface : public ResourceBase { + public: + using Status = ::stream_executor::port::Status; + + // An entry in the compilation cache. The entry is deleted once it has been + // marked for eviction from the cache _and_ all steps that use it have + // completed. When the entry is first created, it is uninitialized and a + // client-supplied compilation function is run outside the cache's lock to + // generate the programs to be stored in the entry. Any other client that + // requests the entry will block until it has been initialized. Each entry has + // a last_use value that set from a monotonically-increasing counter in the + // cache whenever the entry is referenced. When the cache becomes full, + // entries are marked for eviction in LRU order. + // + // The bridge can request XLA to generate separate sharding and unsharding + // programs along with the main program; we use nested fields sharding_entry, + // unsharding_entry to store them under the main entry, and these two fields + // must be either both present or both absent. They have a back pointer + // main_entry to refer to the main program. These nested entries share the + // same cache key and the same lifetime as the main entry, so we use the + // refcount on the main entry to track the access to any of them. + // /-------------------------------\ + // v \ + // main_entry (refcount) -> sharding_entry -> main_entry + // ^ \ + // | \-> unsharding_entry -> main_entry + // \--------------------------------------/ + struct CompilationEntry : public core::RefCounted { + TpuCompilationCacheInterface* parent = nullptr; // Not owned. + bool initialized = false; + + // The Status returned by the compilation function when the entry is + // initialized. This status will be returned to any client that requests the + // entry. + Status initialization_status; + + // The uid describing this entry. + int64 uid; + std::vector proto_key; + + // Counter to keep track of LRU entries for the eviction policy. + int64 last_use = -1; + + // The unique key describing this entry. + std::string subgraph_key; + + // Entries representing the associated sharding and unsharding programs, + // which share the same life time of the owning main entry, so we always use + // the main entry's ref count. + std::unique_ptr sharding_entry; + std::unique_ptr unsharding_entry; + + // The number of 'external' client-held references to the entry. + int external_references = 0; + + std::vector> hlo_metadata; + + // The sum of the SpaceUsed of each of the elements of programs; an estimate + // of how much RAM the entry consumes, used to determine when entries must + // be marked for eviction. + int64 total_size = 0; + + // Only used for the nested sharding/unsharding entries to point to the + // owning main entry. + CompilationEntry* main_entry = nullptr; + + // Debug info in case we miss. + string cache_entry_debug_string; + + // Compiled Tpu program. 
+ std::unique_ptr tpu_program; + }; + + explicit TpuCompilationCacheInterface(int64_t max_cache_size); + ~TpuCompilationCacheInterface() override; + TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete; + TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) + = delete; + + Status CompileIfKeyAbsent( + const TpuCompilationCacheKey& cache_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector>* hlo_metadata, + const std::function& compile_function); + + static TpuCompilationCacheKey CreateCompilationCacheKey( + absl::string_view function_name, uint64 function_library_fingerprint, + absl::string_view mlir_module, + const tensorflow::OpInputList& guaranteed_constants, + const std::vector& dynamic_shapes, + const tensorflow::tpu::TPUCompileMetadataProto& metadata, + const TpuMeshStateInterface& mesh_state); + + string DebugString() const override { return "TpuCompilationCacheInterface"; } + + // Makes a reference holder for this cache, that can be stored in the per-step + // resource manager and will ensure that compiled entries persist until the + // end of a step. + TpuCompilationRefHolder* MakePerStepRefHolder(); + + // Differences between MarkEntryForEviction and Release: + // There are two modes of managing cache entries: + // 1) LRU eviction + pinning; 2) manual. + // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent. + // Otherwise it is manual mode (mainly used by XRT). + // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache + // entries when callers know that they do not need them anymore. + // Release should only be used in mode 2) to explicitly remove an entry. + + // Mark the entry indexed by `subgraph_uid` for eviction. This should only be + // called if per_step_ref_holder was NOT nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64 + // subgraph_uid). + Status MarkEntryForEviction(int64 subgraph_uid); + + // Manually discards a reference to the compiled subgraph. This should only be + // called if per_step_ref_holder was nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). + Status Release(int64 subgraph_uid); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by key. On success a pointer to an EntryRef + // holding the program is returned in entry. + Status Lookup(const string& proto_key, + std::unique_ptr* entry); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by uid. On success a pointer to an EntryRef + // holding the program is returned in entry. + Status Lookup(int64 uid, int proto_index, + std::unique_ptr* entry); + + // Mutates the main entry ref to point to the entry's subentry + // (for sharding/unsharding) or main entry (unchanged) representing the + // fetch target. The entry ref needs to point to the main entry before this + // call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry. + Status ToSubEntryRef(CompilationCacheEntryRef* entry, + CompilationCacheFetchTarget fetch_target) const; + + private: + // Wrapper for a cache entry that holds a reference to the entry until the + // wrapper is deleted. This wrapper is the concrete type of + // CompilationCacheEntryRef returned by Lookup. 
+ class EntryRefImpl : public CompilationCacheEntryRef { + public: + EntryRefImpl(TpuCompilationCacheInterface* parent, CompilationEntry* entry, + int index); + ~EntryRefImpl() override; + + CompilationCacheEntry get() override; + + // Mutates this ref to point to the entry's subentry (for + // sharding/unsharding) or main entry (unchanged) as specified by + // fetch_target. The refcount is kept unchanged, since we only track the + // refcount of the main entry. The entry ref needs to point to the main + // entry before this call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry, and the original entry will be unref'ed. + Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target); + + private: + TpuCompilationCacheInterface* parent_; // Not owned. + // A reference to entry_ is acquired in the constructor and released via + // parent->DiscardEntryRefs in the destructor. + CompilationEntry* entry_; + // The program in entry_ that is returned by the get method. + int index_; + }; + + // Private implementation of the generic CompilationRefHolder that knows about + // CompiledSubgraph entries. + class RefHolder : public TpuCompilationRefHolder { + public: + explicit RefHolder(TpuCompilationCacheInterface* parent) : parent_(parent) { + parent_->Ref(); + } + ~RefHolder() override { + // Release our reference to the parent. + parent_->Unref(); + } + + // Adds entry to the list of entries that will be released when the + // RefHolder is destroyed. Each entry is released via a call to + // parent_->DiscardEntryRefs. + void AddRef(CompilationEntry* entry) { + entries_.push_back(entry); + } + + string DebugString() const override { + return "TpuCompilationCacheInterface::RefHolder"; + } + + private: + TpuCompilationCacheInterface* parent_; // Not owned. + std::vector entries_; + }; + + // The bulk of implementation of CompileIfKeyAbsent() with the exception + // of unloading programs that corresponds to possibly removed cache + // entries. The split helps to manage locking since we prefer to perform + // unloading without holding extra locks. + Status CompileIfKeyAbsentHelper( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + TpuCompilationRefHolder* per_step_ref_holder, int64* uid, + std::vector* proto_key, std::vector* may_modify_variables, + std::vector* removed_entries, + std::vector>* hlo_metadata, + const std::function& compile_function); + + // This is called by the cache when entry is marked for eviction; by + // a RefHolder (via DiscardEntryRefs) when a step completes; and by + // an EntryRefImpl when it is destroyed. Releases one reference to entry + // if more than 1 remains. If only one reference is left, the entry is removed + // from cache_ and is returned to the caller; which must eventually call + // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef + // to avoid holding the lock during program unloading. + ABSL_MUST_USE_RESULT CompilationEntry* DiscardEntryRef( + CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Convenience method called by ~RefHolder without mu_ held. Calls + // DiscardEntryRef on every element of entries. + void DiscardEntryRefs( + gtl::ArraySlice entries); + + // Marks the oldest unmarked entry for eviction. Requires that there is at + // least one such entry. In case the evicted entry had only 1 reference it + // is removed from the cache and returned to the caller which must eventually + // call UnloadAndDestroy. 
+ CompilationEntry* MarkOldestEntryForEviction() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Updates datastructures to indicate that entry, which had been marked for + // eviction, has been looked up. This is called by CompileIfKeyAbsent when an + // entry is newly created, or an entry that has been marked for eviction but + // not yet evicted is looked up. + // + // First the entry is unmarked for eviction, i.e. the cache gains a reference + // to entry, entry's last_use field is set to be the most recent value of + // use_counter_ and entries_by_last_use_ is updated accordingly. + // + // Next, the size of the cache is examined to see if any other entries need to + // be marked for eviction now that entry has been unmarked. While the total + // size of unmarked cached entries is greater than max_cache_size_, entries + // are marked for eviction in LRU order. The most recently used entry is never + // marked for eviction, so an entry larger than the max cache size will remain + // in the cache until it is replaced by something else. In case some entries + // actually were removed from the cache, they are a returned to the caller via + // removed_entries. The caller must eventually delete them by calling + // UnloadAndDestroy. + void LookupEntryMarkedForEviction( + CompilationEntry* entry, std::vector* removed_entries) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Removes the entry with given key from cache. + size_t RemoveEntry(const string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Inserts the given key and entry to cache. + void InsertEntry(const std::string& key, + const TpuCompilationCacheKey& subgraph_key, + CompilationEntry* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the cache key matching given subgraph_key. + std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Creates a new entry by running initialize_programs and places it in the + // cache to be looked up by key. The new entry is in the 'marked for eviction' + // state (not present in entries_by_last_use_) and the caller is expected to + // call LookupEntryMarkedForEviction after InitializeEntry. + // + // **InitializeEntry releases mu_ during the call to initialize_programs.** + CompilationEntry* InitializeEntry( + const string& key, + const std::function& initialize_program, + const TpuCompilationCacheKey& subgraph_key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Unloads the program associated with the entry from all local devices + // and deletes the entry itself. It is assumed no one else has a reference + // to it and all related keys had already been removed from the cache. + // The call can perform device IO so no locks should be held while calling it. + void UnloadAndDestroy(CompilationEntry* entry) ABSL_LOCKS_EXCLUDED(mu_); + + // The maximum size of entries that are stored in the cache before entries are + // marked for eviction. + const int64 max_cache_size_; + + mutable absl::Mutex mu_; + // The total size of entries that are stored and not marked for eviction. + int64 cache_size_ ABSL_GUARDED_BY(mu_) = 0; + + // The total size of entries that are marked for eviction. + int64 marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0; + + // The value to assign to the last_use field of the next entry that is looked + // up. + int64 use_counter_ ABSL_GUARDED_BY(mu_) = 0; + + // session_key_map_ and fingerprint_key_map_ are used for looking up the + // cache_ key matching a given subgraph key. 
When doing a lookup, check + // session_key_map_ first to avoid unnecessay fingerprint computation. + // Map from key prefix + session_handle to a cache_ key. + std::unordered_map session_key_map_ ABSL_GUARDED_BY(mu_); + + // Map from key prefix + fingerprint to a cache_ key. + std::unordered_map fingerprint_key_map_ ABSL_GUARDED_BY(mu_); + + // All the subgraph entries that can be looked up in the cache. An entry is + // marked for eviction iff it is present in cache_ and not in + // entries_by_last_use_. + std::unordered_map cache_store_ + ABSL_GUARDED_BY(mu_); + + // All the subgraph entries that can be looked up in the cache, indexed by + // uid. + absl::node_hash_map entries_by_uid_ + ABSL_GUARDED_BY(mu_); + + // All the protos that can be looked up in the cache, indexed by proto + // key. The value of the map is a subgraph and the index of the proto compiled + // for that subgraph. + std::unordered_map> + entries_by_proto_key_ ABSL_GUARDED_BY(mu_); + + // Map from last_use to entry, used to mark entries for eviction in LRU + // order. If an entry's last_use counter is not present as a key in + // entries_by_last_use_ then the entry has been marked for eviction. + std::map entries_by_last_use_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_INTERFACE_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h new file mode 100644 index 00000000000..49c2eb64944 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ + +#include +#include + +#include "absl/types/optional.h" + +namespace tensorflow { +namespace tpu { + +struct TpuCompilationCacheKey { + // Prefix of the key. + std::string prefix; + + // A boolean flag to specify if `guaranteed_const` is used. Guarantee const is + // normally used in TPU inference to avoid re-copying unchanged variables onto + // the TPU device. It promises the value is identical for every execution in + // the same session even if the actual value changes in later executions. + bool has_guaranteed_const = false; + + // Unique session identifier. It is set when `has_guaranteed_const` is true. + std::string session_handle; + + // Fingerprint of `guaranteed_const` value. It is set when the value of the + // `has_guaranteed_const` is true. Produce the value when necessary. + std::function guaranteed_const_fingerprint; + + // A more verbose key for debugging purpose. 
+ std::string debug_string; + + explicit TpuCompilationCacheKey() {} + explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {} +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc new file mode 100644 index 00000000000..f4f8dbfc80f --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.cc @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" + +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h" + +namespace tensorflow { +namespace tpu { + +namespace { +class CompilationCacheFetchTargetUtility { + public: + CompilationCacheFetchTargetUtility() + : names_({"Invalid", "Main", "Sharding", "Unsharding"}) {} + + std::string name(CompilationCacheFetchTarget target) const { + return names_[static_cast(target)]; + } + + private: + const std::vector names_; +}; + +std::string GetName(CompilationCacheFetchTarget target) { + static const auto* util = new CompilationCacheFetchTargetUtility(); + return util->name(target); +} + +} // namespace + +TpuCompilationCacheLocalLookup::TpuCompilationCacheLocalLookup( + TpuCompilationCacheInterface* cache) + : cache_(cache) {} + +TpuCompilationCacheLocalLookup::~TpuCompilationCacheLocalLookup() { + cache_->Unref(); +} + +Status TpuCompilationCacheLocalLookup::Lookup( + const string& proto_key, std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) { + profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup", + /*level=*/2); + Status s = cache_->Lookup(proto_key, entry); + VLOG(1) << "Looked up key " << proto_key << " in local subgraph cache status " + << s; + if (!s.ok()) { + return s; + } + s = cache_->ToSubEntryRef(entry->get(), fetch_target); + + VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " + << s; + return s; +} + +Status TpuCompilationCacheLocalLookup::Lookup( + int64 uid, int proto_index, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) { + profiler::TraceMe proto_lookup_traceme("Local TPU proto cache lookup by uid", + /*level=*/2); + Status s = cache_->Lookup(uid, proto_index, entry); + VLOG(1) << "Looked up uid " << uid << ", index " << proto_index + << " in local subgraph cache status " << s; + if (!s.ok()) { + return s; + } + s = cache_->ToSubEntryRef(entry->get(), fetch_target); + VLOG(1) << "Fetched subentry: " << GetName(fetch_target) << " with status " + << s; + return s; +} + +string TpuCompilationCacheLocalLookup::DebugString() const { + return "TpuCompilationCacheLocalLookup"; +} + +} // namespace tpu +} // namespace tensorflow diff --git 
a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h new file mode 100644 index 00000000000..138777a438c --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ +#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ + +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" + +namespace tensorflow { +namespace tpu { + +// Base class allowing Execute Ops to look up ISA protos. Different subclasses +// are used when the execute Op is in the same address space as the compile Op, +// and when they need to communicate over RPC. +class TpuCompilationCacheLookup : public ResourceBase { + public: + ~TpuCompilationCacheLookup() override = default; + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by key. On success a wrapper for the proto is + // returned in program. The wrapper is guaranteed to be valid only during the + // execution of the Op requesting the proto. + // + // Only one of the main, sharding, unsharding entries is fetched, as specified + // in fetch_target. + // + // If the compilation does not create sharding/unsharding programs, but the + // fetch_target requests one of them, then after this call + // (*entry)->get().get_executable() will return nullptr. + virtual Status Lookup(const string& proto_key, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) = 0; + + virtual Status Lookup(const string& proto_key, + std::unique_ptr* entry) { + return Lookup(proto_key, std::move(entry), + CompilationCacheFetchTarget::MAIN); + } + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by uid. On success a wrapper for the proto is + // returned in program. The wrapper is guaranteed to be valid only during the + // execution of the Op requesting the proto. + virtual Status Lookup(int64 uid, int proto_index, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) = 0; + + virtual Status Lookup(int64 uid, int proto_index, + std::unique_ptr* entry) { + return Lookup(uid, proto_index, std::move(entry), + CompilationCacheFetchTarget::MAIN); + } +}; + +// Forward declaration to break cycle dependency graph. +class TpuCompilationCacheInterface; + +// Class for looking up ISA protos when the execute and compile Op are in the +// same address space. The proto is simply looked up in the compilation cache, +// without any serialization taking place. 
+class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { + public: + explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache); + ~TpuCompilationCacheLocalLookup() override; + + Status Lookup(const string& proto_key, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) override; + + Status Lookup(int64 uid, int proto_index, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) override; + + string DebugString() const override; + + private: + // The subgraph compilation cache, in the same process address space where the + // lookups are happening. + TpuCompilationCacheInterface* cache_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_COMPILATION_CACHE_LOOKUP_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc new file mode 100644 index 00000000000..ba4e2ccff93 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h" + +namespace tensorflow { +namespace tpu { + +/* static */ +void TpuCompilationCacheMetrics::IncrementCacheLookupCount( + bool is_cache_hit, absl::string_view session_name) { + // A placeholder for tracking metrics. +} + +/* static */ +void TpuCompilationCacheMetrics::SetCacheEntryCount(int64 count) { + // A placeholder for tracking metrics. +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h new file mode 100644 index 00000000000..e30a7a4c013 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_metrics.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tpu { + +// Tracks Tpu compilation cache metrics. 
+class TpuCompilationCacheMetrics { + public: + // Increments the number of cache lookup count. + static void IncrementCacheLookupCount(bool is_cache_hit, + absl::string_view session_name); + + // Sets the total count of cache entries. + static void SetCacheEntryCount(int64 count); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_METRICS_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile.proto b/tensorflow/core/tpu/kernels/tpu_compile.proto new file mode 100644 index 00000000000..5b70de67a05 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile.proto @@ -0,0 +1,144 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +syntax = "proto3"; + +package tensorflow.tpu; + +import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/protobuf/tpu/compile_metadata.proto"; +import "tensorflow/core/tpu/kernels/tpu_executable_info.proto"; + +message PerCoreVariableIndices { + // For each resource variable output, what was the index of the corresponding + // input and was it updated? The indices are sorted by input order. + repeated TPUExecutableInfoProto.UpdateIndexPair variable_indices = 1; +} + +message PerCoreArgShapes { + // Argument shapes for each Tpu core. + repeated xla.ShapeProto shapes = 1; +} + +message PerCoreOutputShapes { + // Output shapes for each Tpu core. + repeated xla.ShapeProto shapes = 1; +} + +message OutputDescriptionProto { + // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. + tensorflow.DataType type = 1; + tensorflow.TensorShapeProto shape = 2; + + // Constant output value, if known to be constant at JIT compilation time. + // 'Tensor' is in host memory. + bool is_constant = 3; + tensorflow.TensorProto constant_value = 4; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int32 input_index = 5; + + // Whether this output is a TensorList. + bool is_tensor_list = 6; +} + +// Describes a variable write side effect of the computation. +message ResourceUpdateProto { + // Index of the input that contains the variable resource to write to. + int32 input_index = 1; + + // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. + tensorflow.DataType type = 2; + tensorflow.TensorShapeProto shape = 3; + + // Was the value of the variable modified by the computation? + // (Always true, unless `return_updated_values_for_all_resources` is true.) 
+ bool modified = 4; + + // If the resource is a TensorArray, the set of gradients read or written. + map tensor_array_gradients_accessed = 5; +} + +// Describes the result of a XLA Compiler compilation. +message XlaCompilationResultProto { + // Vector that maps from the parameters of the XLA computation to their + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. + repeated int32 input_mappings = 1; + + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. + repeated xla.ShapeProto xla_input_shapes = 2; + + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. + xla.ShapeProto xla_output_shape = 3; + + // TensorFlow shapes of outputs, together with the values of any + // constant arguments. Vector indexed by Tensorflow _Retval number, + // containing both constant and non-constant results. + repeated OutputDescriptionProto outputs = 4; + + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla.HostComputeMetadata host_compute_metadata = 5; + + // Resources whose values were updated by the computation, ordered + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant + // results in the outputs of XLA computation. + repeated ResourceUpdateProto resource_updates = 6; + + // The XLA computation built from the tensorflow subgraph. + xla.HloModuleProto computation = 7; +} + +// TpuAotCompilationRequestProto represents a compilation request for performing +// ahead-of-time (AOT) compilation of XLA Computations into XLA HLO IR. +message TpuAotCompilationRequestProto { + // A set of HLO module built to run concurrently + // across different devices. + xla.HloModuleGroupProto hlo_module_group = 1; + + // Compilation metadata. + TPUCompileMetadataProto metadata = 2; + + // DeviceAssignmentProto is a serialized form of DeviceAssignment class, which + // represents the device ids assigned to a set of replicated computations. + // See xla::DeviceAssignment class comment for more details. + xla.DeviceAssignmentProto device_assignment = 3; + + // Per TPU core program arguments shapes. + repeated PerCoreArgShapes per_core_arg_shapes = 4; + + // Per TPU core program outputs shapes. + repeated PerCoreOutputShapes per_core_output_shapes = 5; + + // Per TPU core information containing what was the index of the corresponding + // input and if whether it was updated. The indices are sorted by input order. + repeated PerCoreVariableIndices per_core_variable_indices = 6; + + // XLA compiler compilation result. + XlaCompilationResultProto compilation_result = 7; +} diff --git a/tensorflow/core/tpu/kernels/tpu_compile_c_api.h b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h new file mode 100644 index 00000000000..53e79aa51b0 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile_c_api.h @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ + +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +enum TpuCoreTypeEnum { + kTensorCore, + kEmbeddingV1, + kEmbeddingV2, +}; + +typedef struct XLA_TpuProgram XLA_TpuProgram; + +// Property for creating compilation cache key. +struct CompilationCacheKeyProperty { + const char* config_prefix; + const char* shapes_prefix; + const char* function_name; + const char* mlir_module; + const int32_t* device_ids; + size_t device_ids_size; + int32_t guaranteed_constants_size; + uint64_t function_library_fingerprint; + int32_t num_cores_per_replica; + int32_t num_replicas; + const XLA_TpuMeshState* mesh_state; +}; + +extern "C" { + +// Creates a new TPU program. +XLA_TpuProgram* TpuProgram_New(); + +// Destroys the `tpu_program`. +void TpuProgram_Free(XLA_TpuProgram* tpu_program); + + +// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and +// destroyed, it is in an unusable state. +void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, + SE_Status* status); + +// Gets TPU program size in bytes from the `tpu_program`. +int64_t TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); + +// Logs the summary of current memory state snapshot of the `tpu_program`. +bool TpuProgram_LogProgramMemorySummary(const XLA_TpuProgram* tpu_program); + +// Gets TPU program executable info from the `tpu_program`. +void TpuProgram_GetExecutableInfo(const XLA_TpuProgram* tpu_program, + TpuSerializedProto* executable_info); + +// Gets host transfer info proto. +void TpuProgram_GetHostTransferInfo( + const XLA_TpuProgram* tpu_program, + TpuSerializedProto* host_transfer_info); + +// Gets HLO metadata proto. +void TpuProgram_GetHloMetadata(const XLA_TpuProgram* tpu_program, + TpuSerializedProto* hlo_metadata); + +// Returns the number of available TPU core count. +int TpuTopology_AvailableCoreCount(const XLA_TpuMeshState* mesh_state, + TpuCoreTypeEnum tpu_core_type); + +// Creates a unique compilation cache `key` used for `put` and `get` operations. +// Returned buffer is heap-allocated and must be owned. +const char* TpuCompile_CreateCompilationCacheKey( + CompilationCacheKeyProperty property); + +// Creates a guaranteed const fingerprint. Guarantee const is normally used in +// TPU inference to avoid re-copying unchanged variables onto the TPU device. +// It promises the value is identical for every execution in the same session +// even if the actual value changes in later executions. +uint64_t TpuCompile_CreateGuaranteedConstFingerprint(uint64_t fingerprint, + const char* data, + size_t size); + +// Checks if whether a TPU compilation is enabled. +bool TpuCompile_IsTpuCompilationEnabled(); + +// Executes the computations using XLA TPU compiler and returns TPU programs +// ready for execution. 
+void TpuCompile_CompileAheadOfTime( + TpuSerializedProto aot_compilation_request, + XLA_TpuProgram** tpu_programs[], + size_t* count, SE_Status* status); + +// Builds `DeviceAssignment` from `TpuCompileMetadata` serialized proto. +void TpuCompile_BuildXLADeviceAssignment( + TpuSerializedProto serialized_tpu_compile_metadata, + const XLA_TpuMeshState* mesh_state, + TpuSerializedProto* serialized_device_assignment, SE_Status* status); + +// Converts an XLA `Shape` into its equivalent TPU `Shape` representation. +void TpuCompile_ToTpuShapeRepresentation( + TpuSerializedProto serialized_xla_shape, int data_type, + bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape, + SE_Status* status); + +} // extern "C" + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc new file mode 100644 index 00000000000..49a2a089adf --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h" + +namespace tensorflow { +namespace internal { + +namespace { +static bool tpu_compilation_cancellation_terminates_process = true; +static bool tpu_compilation_failure_closes_chips = true; +} // namespace + +void SetTpuCompilationCancellationTerminatesProcess(bool b) { + tpu_compilation_cancellation_terminates_process = b; +} + +bool TpuCompilationCancellationTerminatesProcess() { + return tpu_compilation_cancellation_terminates_process; +} + +void SetTpuCompilationFailureClosesChips(bool value) { + tpu_compilation_failure_closes_chips = value; +} + +bool TpuCompilationFailureClosesChips() { + return tpu_compilation_failure_closes_chips; +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_options.h b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h new file mode 100644 index 00000000000..b81fe4a3b75 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_options.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ + +#include + +namespace tensorflow { +namespace internal { + +// Setter and getter that determine how TPUCompile responds to cancelled +// compilation. By default this is true, meaning cancelled compilation will +// abort the process, since that's the only mechanism we have available. +// +// Setting this to false allows the process to remain alive, and should only be +// used in tests. +void SetTpuCompilationCancellationTerminatesProcess(bool b); +bool TpuCompilationCancellationTerminatesProcess(); + +// Setter and getter that determine whether TPU compilation failure will cause +// chips to close. By default this is true, it is suitable for training. For +// inference, we never want servers to die and thus chips will keep alive. +// See b/109873767. +void SetTpuCompilationFailureClosesChips(bool value); +bool TpuCompilationFailureClosesChips(); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc new file mode 100644 index 00000000000..d42c604fd1e --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -0,0 +1,439 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" + +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/dump.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" + +namespace tensorflow { +namespace tpu { + +using stream_executor::port::Status; +using stream_executor::port::StatusOr; +using xla::ComputationLayout; +using xla::DebugOptions; +using xla::DeviceAssignment; +using xla::HloModuleConfig; +using xla::HloSharding; +using xla::InvalidArgument; +using xla::ProgramShape; +using xla::Shape; +using xla::ShapeTree; +using xla::ShapeUtil; + +Status ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) { + TF_RETURN_IF_ERROR( + xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) { + return InvalidArgument( + "Shape used to set computation result layout %s is not compatible " + "with result shape %s", + xla::ShapeUtil::HumanStringWithLayout(client_shape), + xla::ShapeUtil::HumanString(result_shape)); + } + return Status::OK(); +} + +StatusOr> CreateModuleConfig( + const ProgramShape& program_shape, absl::Span argument_shapes, + absl::optional result_layout, + absl::optional device_assignment, int replica_count, + int num_partitions, const DebugOptions* debug_options, const int* seed, + const int* launch_id, const bool* alias_passthrough_params, + const xla::FusionConfigCollection* fusion_config_collection, + const std::vector>* fusion_config) { + auto config = absl::make_unique(program_shape); + ComputationLayout* computation_layout = + config->mutable_entry_computation_layout(); + if (program_shape.parameters_size() != argument_shapes.size()) { + return InvalidArgument("computation takes %d parameters, but %u given", + program_shape.parameters_size(), + argument_shapes.size()); + } + for (int i = 0; i < argument_shapes.size(); ++i) { + // Verify that shape of arguments matches the shape of the arguments in the + // ProgramShape. + if (!ShapeUtil::Compatible(argument_shapes[i], + program_shape.parameters(i))) { + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shapes[i])); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_shapes[i])); + } + + if (result_layout.has_value()) { + TF_RETURN_IF_ERROR( + ValidateResultShape(result_layout.value(), program_shape.result())); + TF_RETURN_IF_ERROR( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + result_layout.value())); + } else { + // If the result layout is not set, then choose the default. 
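// (Illustrative example, not from this change: for a result shape of f32[2,3]
//  the call below picks XLA's default minor-to-major layout {1,0}, exactly as
//  if the caller had passed a result_layout carrying that layout.)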
+ computation_layout->mutable_result_layout()->SetToDefaultLayout(); + } + + config->set_replica_count(replica_count); + config->set_num_partitions(num_partitions); + if (seed != nullptr) { + config->set_seed(*seed); + } + if (launch_id != nullptr) { + config->set_launch_id(*launch_id); + } + if (debug_options != nullptr) { + config->set_debug_options(*debug_options); + } else { + config->set_debug_options(xla::GetDebugOptionsFromFlags()); + } + + // TODO(henrytan): set intra_op_parallelism_threads. + // Reference: + // tensorflow/compiler/xla/service/service.cc?l=324. + + if (device_assignment.has_value()) { + config->set_static_device_assignment(device_assignment.value()); + } + + if (alias_passthrough_params != nullptr) { + config->set_alias_passthrough_params(*alias_passthrough_params); + } + + if (fusion_config_collection != nullptr && fusion_config != nullptr && + *fusion_config_collection != xla::FusionConfigCollection::kOff) { + config->set_fusion_config_collection(*fusion_config_collection); + *config->mutable_fusion_config() = *fusion_config; + } + + return std::move(config); +} + +StatusOr> CreateModuleConfig( + const xla::ProgramShape& program_shape, + absl::Span argument_shapes, + absl::optional result_layout, + absl::optional device_assignment, int replica_count, + int num_partitions, const DebugOptions* debug_options) { + return CreateModuleConfig(program_shape, argument_shapes, result_layout, + device_assignment, replica_count, num_partitions, + debug_options, /*seed=*/nullptr, + /*launch_id=*/nullptr, + /*alias_passthrough_params=*/nullptr, + /*fusion_config_collection=*/nullptr, + /*fusion_config=*/nullptr); +} + +ShapeTree GetSubtree( + const ShapeTree& tuple_shape_tree, int element_index) { + ShapeTree element_shape_tree( + xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(), + element_index), + HloSharding::Replicate()); + + xla::ShapeIndex src_index; + src_index.push_back(element_index); + element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {}); + return element_shape_tree; +} + +Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding, + int64 device) { + if (shape.IsTuple()) { + ShapeTree tuple_shape_tree = sharding.GetAsShapeTree(shape); + std::vector arg_shapes; + for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) { + Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i); + HloSharding element_sharding = tuple_shape_tree.element({i}); + if (element_shape.IsTuple()) { + element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i)); + } + if (element_sharding.UsesDevice(device)) { + arg_shapes.push_back( + GetPerDeviceShape(element_shape, element_sharding, device)); + } + } + return xla::ShapeUtil::MakeTupleShape(arg_shapes); + } + + if (sharding.IsTileMaximal()) { + return shape; + } + + std::vector dimensions; + std::vector offset = sharding.TileOffsetForDevice(shape, device); + std::vector limit = sharding.TileLimitForDevice(shape, device); + for (int64 i = 0; i < limit.size(); ++i) { + dimensions.push_back(limit[i] - offset[i]); + } + if (shape.has_layout()) { + return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions, + shape.layout().minor_to_major()); + } + return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions); +} + +Status AddVariableUpdatesToCores( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const std::vector& arg_core_mapping, + std::vector* may_modify_variables, + std::vector>* 
per_core_output_shapes, + std::vector>>* per_core_variable_indices) { + // Add all variables to the corresponding core. + may_modify_variables->resize(metadata.num_cores_per_replica(), false); + int resource_update_pos = 0; + for (int i = 0; i < metadata.args_size(); ++i) { + const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i); + if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) { + const auto& sharding = proto_arg.sharding(); + bool updated = false; + if (resource_update_pos < compilation_result.resource_updates.size()) { + const XlaCompiler::ResourceUpdate& update = + compilation_result.resource_updates[resource_update_pos]; + if (update.input_index == i) { + updated = true; + int pos = compilation_result.outputs.size() + resource_update_pos; + xla::Shape shape = xla::ShapeUtil::GetTupleElementShape( + compilation_result.xla_output_shape, pos); + auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) { + (*per_core_output_shapes)[core].push_back(per_core_shape); + (*may_modify_variables)[core] = + (*may_modify_variables)[core] || update.modified; + }; + if (sharding.type() == xla::OpSharding::MAXIMAL) { + add_to_core(sharding.tile_assignment_devices(0), shape); + } else if (sharding.type() == xla::OpSharding::OTHER) { + auto sharding_or = + xla::HloSharding::FromProto(proto_arg.sharding()); + TF_RET_CHECK(sharding_or.ok()); + for (int64 core : proto_arg.sharding().tile_assignment_devices()) { + xla::Shape per_core_shape = + GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core); + add_to_core(core, per_core_shape); + } + } else { + TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); + for (int64 core = 0; core < metadata.num_cores_per_replica(); + ++core) { + add_to_core(core, shape); + } + } + ++resource_update_pos; + } + } + if (sharding.type() == xla::OpSharding::MAXIMAL) { + (*per_core_variable_indices)[sharding.tile_assignment_devices(0)] + .push_back( + std::pair(arg_core_mapping[i].indices[0], updated)); + } else if (sharding.type() == xla::OpSharding::OTHER) { + for (int core : sharding.tile_assignment_devices()) { + (*per_core_variable_indices)[core].push_back( + std::pair(arg_core_mapping[i].indices[core], updated)); + } + } else { + TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED); + for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) { + (*per_core_variable_indices)[core].push_back( + std::pair(arg_core_mapping[i].indices[core], updated)); + } + } + } + } + return Status::OK(); +} + +Status ComputeOutputShapesForEachCore( + const tpu::TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + std::vector>* per_core_output_shapes) { + for (int i = 0; i < metadata.retvals_size(); ++i) { + const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i); + TF_RET_CHECK(!compilation_result.outputs[i].is_constant) + << "TPU compilation output " << i + << " has a compile-time constant value. 
" + "This should never happen."; + + xla::Shape shape = xla::ShapeUtil::GetTupleElementShape( + compilation_result.xla_output_shape, i); + auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) { + (*per_core_output_shapes)[core].push_back(std::move(per_core_shape)); + }; + if (retval.sharding().type() == xla::OpSharding::MAXIMAL) { + add_shape_to_core(retval.sharding().tile_assignment_devices(0), + std::move(shape)); + } else if (retval.sharding().type() == xla::OpSharding::OTHER) { + auto sharding_or = xla::HloSharding::FromProto(retval.sharding()); + TF_RET_CHECK(sharding_or.ok()); + for (int64 core : retval.sharding().tile_assignment_devices()) { + xla::Shape per_core_shape = + GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core); + add_shape_to_core(core, std::move(per_core_shape)); + } + } else { + TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED) + << "Not all of the constant tensors were consumed."; + for (int core = 0; core < per_core_output_shapes->size(); ++core) { + add_shape_to_core(core, shape); + } + } + } + return Status::OK(); +} + +Status CreateHloModules( + const TPUCompileMetadataProto& metadata, + const tensorflow::XlaCompiler::CompilationResult& compilation_result, + const absl::optional& device_assignment, + std::vector>* hlo_modules) { + TF_RET_CHECK( + compilation_result.computation->proto().has_host_program_shape()); + + auto debug_options = xla::DebugOptions(); + debug_options.set_xla_step_marker_location(metadata.step_marker_location()); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig( + xla::ProgramShape( + compilation_result.computation->proto().host_program_shape()), + compilation_result.xla_input_shapes, + compilation_result.xla_output_shape, device_assignment, + metadata.num_replicas(), metadata.num_cores_per_replica(), + &debug_options)); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + xla::HloModule::CreateFromProto(compilation_result.computation->proto(), + *module_config)); + DumpHloModuleIfEnabled(*hlo_module, "before_optimizations"); + hlo_modules->push_back(std::move(hlo_module)); + + return Status::OK(); +} + +XlaCompilationResultProto SerializeCompilationResult( + const XlaCompiler::CompilationResult& compilation_result) { + XlaCompilationResultProto compilation_result_proto; + for (int input_mapping : compilation_result.input_mapping) { + compilation_result_proto.add_input_mappings(input_mapping); + } + + for (const Shape& input_shape : compilation_result.xla_input_shapes) { + *(compilation_result_proto.add_xla_input_shapes()) = input_shape.ToProto(); + } + *(compilation_result_proto.mutable_xla_output_shape()) = + compilation_result.xla_output_shape.ToProto(); + + for (const XlaCompiler::OutputDescription& output_description : + compilation_result.outputs) { + auto* new_output = compilation_result_proto.add_outputs(); + new_output->set_type(output_description.type); + output_description.shape.AsProto(new_output->mutable_shape()); + new_output->set_is_constant(output_description.is_constant); + output_description.constant_value.AsProtoField( + new_output->mutable_constant_value()); + new_output->set_input_index(output_description.input_index); + new_output->set_is_tensor_list(output_description.is_tensor_list); + } + + *compilation_result_proto.mutable_host_compute_metadata() = + compilation_result.host_compute_metadata; + + for (const XlaCompiler::ResourceUpdate& resource_update : + compilation_result.resource_updates) { + auto* new_resource_update = 
compilation_result_proto.add_resource_updates(); + new_resource_update->set_input_index(resource_update.input_index); + new_resource_update->set_type(resource_update.type); + resource_update.shape.AsProto(new_resource_update->mutable_shape()); + new_resource_update->set_modified(resource_update.modified); + for (const std::string& gradient_access : + resource_update.tensor_array_gradients_accessed) { + new_resource_update->mutable_tensor_array_gradients_accessed()->insert( + {gradient_access, true}); + } + } + + if (compilation_result.computation != nullptr) { + *compilation_result_proto.mutable_computation() = + compilation_result.computation->proto(); + } + + return compilation_result_proto; +} + +StatusOr CreateTpuAotCompilationRequest( + const xla::HloModuleGroup& module_group, + const XlaCompiler::CompilationResult& compilation_result, + const TPUCompileMetadataProto& metadata, + const std::vector>& per_core_arg_shapes, + const std::vector>& per_core_output_shapes, + const std::vector>>& + per_core_variable_indices, + const absl::optional& device_assignment) { + VLOG(1) << "CreateTpuAotCompilationRequest."; + TpuAotCompilationRequestProto aot_request; + *(aot_request.mutable_hlo_module_group()) = module_group.ToProto(); + *(aot_request.mutable_metadata()) = metadata; + if (device_assignment.has_value()) { + xla::DeviceAssignmentProto device_assignment_proto; + Status status = device_assignment->Serialize(&device_assignment_proto); + if (!status.ok()) { + return status; + } + *(aot_request.mutable_device_assignment()) = device_assignment_proto; + } + + for (const auto& arg_shapes : per_core_arg_shapes) { + auto* new_shape_list = aot_request.add_per_core_arg_shapes(); + for (const auto& arg_shape : arg_shapes) { + *new_shape_list->add_shapes() = arg_shape.ToProto(); + } + } + + for (const auto& output_shapes : per_core_output_shapes) { + auto* new_shape_list = aot_request.add_per_core_output_shapes(); + for (const auto& output_shape : output_shapes) { + *new_shape_list->add_shapes() = output_shape.ToProto(); + } + } + + for (const auto& variable_indices : per_core_variable_indices) { + auto* new_list = aot_request.add_per_core_variable_indices(); + for (const auto& variable_index : variable_indices) { + auto* core_index = new_list->add_variable_indices(); + core_index->set_index(variable_index.first); + core_index->set_updated(variable_index.second); + } + } + + XlaCompilationResultProto compilation_result_proto = + SerializeCompilationResult(compilation_result); + *aot_request.mutable_compilation_result() = compilation_result_proto; + + VLOG(1) << "TpuAotCompilationRequest:\n" << aot_request.DebugString(); + return aot_request; +} +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h new file mode 100644 index 00000000000..0f21e458828 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ + +#include +#include + +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace tpu { + +namespace se = ::stream_executor; + +// Describes the position of an argument or return value after the computation +// has been partitioned into cores. +struct ShardingAndIndex { + // Sharding across cores. + ::xla::OpSharding sharding; + // Argument/return value number. If sharding is single-core, `indices` has a + // single element; otherwise, it has num_cores elements. + std::vector indices; +}; + +// TODO(b/158279168): Dedup with internal version. +// Return the per-device shape for a `shape` with a given `sharding`. 
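// Illustrative worked example (numbers are not from this change): for
//   shape    = f32[8,16]
//   sharding = a 2x1 tile assignment over devices {0, 1}
// the per-device dimensions are computed as tile_limit - tile_offset, so both
// device 0 and device 1 receive f32[4,16]; a maximal (single-device) sharding
// returns the original shape unchanged.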
+xla::Shape GetPerDeviceShape(const xla::Shape& shape, + const xla::HloSharding& sharding, + int64 device); + +stream_executor::port::StatusOr> +CreateModuleConfig( + const xla::ProgramShape& program_shape, + absl::Span argument_shapes, + absl::optional result_layout, + absl::optional device_assignment, + int replica_count, int num_partitions, + const xla::DebugOptions* debug_options, const int* seed, + const int* launch_id, const bool* alias_passthrough_params, + const xla::FusionConfigCollection* fusion_config_collection, + const std::vector>* fusion_config); + +stream_executor::port::StatusOr> +CreateModuleConfig( + const xla::ProgramShape& program_shape, + absl::Span argument_shapes, + absl::optional result_layout, + absl::optional device_assignment, + int replica_count, + int num_partitions, const xla::DebugOptions* debug_options); + +xla::ShapeTree GetSubtree( + const xla::ShapeTree& tuple_shape_tree, + int element_index); + +xla::Shape GetPerDeviceShape(const xla::Shape& shape, + const xla::HloSharding& sharding, + int64 device); + +Status AddVariableUpdatesToCores( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const std::vector& arg_core_mapping, + std::vector* may_modify_variables, + std::vector>* per_core_output_shapes, + std::vector>>* per_core_variable_indices); + +se::port::Status ComputeOutputShapesForEachCore( + const tpu::TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + std::vector>* per_core_output_shapes); + +se::port::Status CreateHloModules( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const absl::optional& device_assignment, + std::vector>* hlo_modules); + +se::port::StatusOr +CreateTpuAotCompilationRequest( + const xla::HloModuleGroup& module_group, + const XlaCompiler::CompilationResult& compilation_result, + const TPUCompileMetadataProto& metadata, + const std::vector>& per_core_arg_shapes, + const std::vector>& per_core_output_shapes, + const std::vector>>& + per_core_variable_indices, + const absl::optional& device_assignment); +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc new file mode 100644 index 00000000000..7fa345d735c --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -0,0 +1,298 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h" + +#include + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" +#include "tensorflow/core/tpu/tpu_config_c_api.h" +#include "tensorflow/core/tpu/tpu_configuration.h" +#include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" + +namespace tensorflow { +namespace { + +Status GetTpuMeshStateInterface(const ResourceMgr* rmgr, + tpu::TpuMeshStateInterface** state) { + if (!rmgr->Lookup(rmgr->default_container(), + tpu::kTpuMeshCommonStateResourceName, state) + .ok()) { + return errors::FailedPrecondition( + "The TPU system has not been initialized."); + } + return Status::OK(); +} + +// Attempt to delete resource_name from resource_manager's default_container. +// Returns OK if the deletion succeeded, or if the resource was not found. Else +// return the deletion error. +template +Status DeleteIfExists(ResourceMgr* resource_manager, + const char* resource_name) { + VLOG(1) << "Removing resource " << resource_name << " if it exists"; + Status status = resource_manager->Delete( + resource_manager->default_container(), resource_name); + if (status.ok()) { + VLOG(1) << "Removed existing resource " << resource_name; + return Status::OK(); + } + if (status.code() == error::NOT_FOUND) { + VLOG(1) << "No resource " << resource_name << " to remove"; + return Status::OK(); + } + VLOG(1) << "Error removing resource " << resource_name << " : " << status; + return status; +} + +} // namespace + +void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "ConfigureDistributedTpuOp"; + XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp"); + + std::vector num_devices_per_host; + int chips_per_host = -1; + for (int i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& input_tensor = ctx->input(i); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(input_tensor.shape()), + errors::InvalidArgument("Input ", i, " should be a scalar but has ", + input_tensor.dims(), " dimensions")); + if (chips_per_host == -1) { + chips_per_host = input_tensor.scalar()(); + } else { + OP_REQUIRES( + ctx, chips_per_host == input_tensor.scalar()(), + errors::Internal("Host ", i, " has ", input_tensor.scalar()(), + " TPU chips but host 0 has ", chips_per_host)); + } + num_devices_per_host.push_back(input_tensor.scalar()()); + } + + TF_Status* status = TF_NewStatus(); + size_t host_config_output_size; + char* host_config_output; + + auto* rmgr = GetTPUConfigResourceMgr(); + OP_REQUIRES_OK(ctx, DeleteIfExists( + rmgr, tpu::kTpuMeshCommonStateResourceName)); + + ConfigureDistributedTpuOp_DoWork( + num_devices_per_host.size(), num_devices_per_host.data(), + &host_config_output_size, &host_config_output, status); + + OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(), + tpu::kTpuMeshCommonStateResourceName, + tpu::TpuMeshStateInterface::Create())); + + Tensor* ctx_output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); + ctx_output->scalar()() = + std::string(host_config_output, host_config_output_size); + + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + 
TF_DeleteStatus(status); + TpuConfigurationApi_FreeCharArray(host_config_output); + + VLOG(1) << "ConfigureDistributedTpuOp done"; +} + +void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "WaitForDistributedTpuOp"; + XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp"); + + size_t num_devices_per_host = -1; + size_t num_hosts = ctx->num_inputs(); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i); + OP_REQUIRES( + ctx, host_ordinal_to_global_device_id_tensor.dims() == 1, + errors::InvalidArgument("Input ", i, " should be a vector but has ", + host_ordinal_to_global_device_id_tensor.dims(), + " dimensions")); + } + + std::vector> mapping; + std::vector mapping_arg; + + mapping.resize(ctx->num_inputs()); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i); + const auto host_ordinal_to_global_device_id = + host_ordinal_to_global_device_id_tensor.flat(); + if (num_devices_per_host == -1) { + num_devices_per_host = + host_ordinal_to_global_device_id_tensor.dim_size(0); + } else { + OP_REQUIRES(ctx, + num_devices_per_host == + host_ordinal_to_global_device_id_tensor.dim_size(0), + errors::Internal( + "Host ", i, " has ", + host_ordinal_to_global_device_id_tensor.dim_size(0), + " TPU devices but host 0 has ", num_devices_per_host)); + } + for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0); + ++j) { + int32_t global_device_id = host_ordinal_to_global_device_id(j); + mapping[i].push_back(global_device_id); + } + mapping_arg.push_back(mapping[i].data()); + } + + TF_Status* status = TF_NewStatus(); + size_t tpu_topology_output_size; + char* tpu_topology_output; + + tpu::TpuMeshStateInterface* mesh_state; + auto* rmgr = GetTPUConfigResourceMgr(); + OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state)); + core::ScopedUnref mesh_state_unref(mesh_state); + + WaitForDistributedTpuOp_DoWork( + num_hosts, num_devices_per_host, + const_cast(mapping_arg.data()), mesh_state, + &tpu_topology_output_size, &tpu_topology_output, status); + + Tensor* ctx_output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); + ctx_output->scalar()() = + std::string(tpu_topology_output, tpu_topology_output_size); + + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + TF_DeleteStatus(status); + TpuConfigurationApi_FreeCharArray(tpu_topology_output); + + VLOG(1) << "WaitForDistributedTpuOp done"; +} + +void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "ShutdownDistributedTpuOp"; + XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp"); + + TF_Status* status = TF_NewStatus(); + OP_REQUIRES_OK(ctx, DeleteIfExists( + GetTPUConfigResourceMgr(), + tpu::kTpuMeshCommonStateResourceName)); + ShutdownDistributedTpuOp_DoWork(status); + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + TF_DeleteStatus(status); + + VLOG(1) << "ShutdownDistributedTpuOp done"; +} + +void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "InitializeHostForDistributedTpuOp"; + XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp"); + + auto tpu_host_config = ctx->input(0).scalar()(); + + size_t device_id_output_size; + int32_t* device_id_output; + TF_Status* status = TF_NewStatus(); + + InitializeHostForDistributedTpuOp_DoWork( + tpu_host_config.size(), tpu_host_config.data(), + enable_whole_mesh_compilations_, &device_id_output_size, + &device_id_output, status); + + Tensor* 
ctx_output; + OP_REQUIRES_OK( + ctx, ctx->allocate_output( + 0, TensorShape({static_cast(device_id_output_size)}), + &ctx_output)); + + for (size_t i = 0; i < device_id_output_size; ++i) { + ctx_output->flat()(i) = device_id_output[i]; + } + + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + TF_DeleteStatus(status); + TpuConfigurationApi_FreeInt32Array(device_id_output); + + VLOG(1) << "InitializeHostForDistributedTpuOp done"; +} + +void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "SetGlobalTPUArrayOp"; + XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp"); + + auto tpu_topology = ctx->input(0).scalar()(); + TF_Status* status = TF_NewStatus(); + + SetGlobalTPUArrayOp_DoWork(tpu_topology.size(), tpu_topology.data(), status); + + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + TF_DeleteStatus(status); + + VLOG(1) << "SetGlobalTPUArrayOp done"; +} + +void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) { + VLOG(1) << "DisconnectDistributedTpuChipsOp"; + XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp"); + + TF_Status* status = TF_NewStatus(); + int32_t number_of_chips_output = 0; + + DisconnectDistributedTpuChipsOp_DoWork(&number_of_chips_output, status); + + Tensor* ctx_output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output)); + ctx_output->scalar()() = number_of_chips_output; + + OP_REQUIRES_OK(ctx, StatusFromTF_Status(status)); + TF_DeleteStatus(status); + + VLOG(1) << "DisconnectDistributedTpuChipsOp done"; +} + +// These ops execute on the TPU_SYSTEM device only. +REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU") + .Device(DEVICE_TPU_SYSTEM) + .HostMemory("output"), + ConfigureDistributedTpuOp); +REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU") + .Device(DEVICE_TPU_SYSTEM) + .HostMemory("inputs") + .HostMemory("topology"), + WaitForDistributedTpuOp); +REGISTER_KERNEL_BUILDER( + Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM), + ShutdownDistributedTpuOp); +REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU") + .Device(DEVICE_TPU_SYSTEM) + .HostMemory("input") + .HostMemory("tpu_ids"), + InitializeHostForDistributedTpuOp); +REGISTER_KERNEL_BUILDER( + Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"), + SetGlobalTPUArrayOp); +REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem") + .Device(DEVICE_TPU_SYSTEM) + .HostMemory("number_of_tpu_chips"), + DisconnectDistributedTpuChipsOp); + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h new file mode 100644 index 00000000000..f75a47e5aaf --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.h @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// The ConfigureDistributedTpu op is used to start an TPUDriver from +// TensorFlow. It should be run on a TPU_SYSTEM device and returns the +// connection host:port for the CompilationCacheServer. The +// CompilationCacheServer will remain live until the device's Resource Manager +// is cleared or a ShutdownDistributedTpuOp is run on the same device. +class ConfigureDistributedTpuOp : public OpKernel { + public: + explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES( + ctx, ctx->num_inputs() > 0, + errors::Internal("_ConfigureDistributedTPU needs at least one input")); + } + void Compute(OpKernelContext* ctx) override; + ~ConfigureDistributedTpuOp() override {} + + private: + // ConfigureDistributedTpuOp is neither copyable nor movable. + ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete; + ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) = + delete; +}; + +// The WaitForDistributedTpuOp op is used to block execution until +// the distributed Tpu system has started up. It must be run on +// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run +// on, after all of the InitializeHostForDistributedTpuOp Ops have +// completed. +class WaitForDistributedTpuOp : public OpKernel { + public: + explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, + ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_)); + OP_REQUIRES(ctx, startup_timeout_sec_ > 0, + errors::InvalidArgument("startup_timeout_sec ", + startup_timeout_sec_, " must be >0")); + } + void Compute(OpKernelContext* ctx) override; + ~WaitForDistributedTpuOp() override {} + + private: + // The time to wait for all hosts to start up. + int startup_timeout_sec_; + + // WaitForDistributedTpuOp is neither copyable nor movable. + WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete; + WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete; +}; + +// The ShutdownDistributedTpu op is used to stop a running TPUDriver from +// TensorFlow. It should be run on the TPU_SYSTEM device where +// ConfigureDistributedTpuOp was run. +class ShutdownDistributedTpuOp : public OpKernel { + public: + explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~ShutdownDistributedTpuOp() override {} + + private: + // ShutdownDistributedTpuOp is neither copyable nor movable. + ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete; + ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete; +}; + +// The InitializeHostForDistributedTpu op is used to initialize the +// TPUPlatform on a host in a distributed TPU system. It should be +// run on every host containing TPU devices before any other Ops that use +// TPU are run. 
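// Sketch of a typical bring-up/tear-down order, inferred from the per-op
// comments in this file and the kernel registrations in
// tpu_configuration_ops.cc (illustrative only, not an additional contract):
//
//   _ConfigureDistributedTPU                  (TPU_SYSTEM device, once)
//   _InitializeHostForDistributedTPU          (once per host)
//   _WaitForDistributedTPU                    (TPU_SYSTEM device, after all
//                                              hosts have initialized)
//   _SetGlobalTPUArray                        (once per host, consuming the
//                                              topology produced by the wait)
//   ... TPU computations ...
//   _DisconnectHostFromDistributedTPUSystem   (once per host)
//   _ShutdownDistributedTPU                   (TPU_SYSTEM device, last)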
+class InitializeHostForDistributedTpuOp : public OpKernel { + public: + explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + ctx->GetAttr("enable_whole_mesh_compilations", + &enable_whole_mesh_compilations_) + .IgnoreError(); + } + + void Compute(OpKernelContext* ctx) override; + + ~InitializeHostForDistributedTpuOp() override {} + + private: + // InitializeHostForDistributedTpuOp is neither copyable nor movable. + InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) = + delete; + InitializeHostForDistributedTpuOp& operator=( + const InitializeHostForDistributedTpuOp&) = delete; + + bool enable_whole_mesh_compilations_ = false; +}; + +// The SetGlobalTPUArray op is used to initialize the TPUPlatform on a +// host in a distributed TPU system. It should be run on every host +// containing TPU devices before any other Ops that use TPU are run. +class SetGlobalTPUArrayOp : public OpKernel { + public: + explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~SetGlobalTPUArrayOp() override {} + + private: + // SetGlobalTPUArrayOp is neither copyable nor movable. + SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete; + SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete; +}; + +// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a +// host from a running TPUDriver instance. It should be run on every host +// containing TPU devices before the ShutdownDistributedTpuOp is run on +// the TPU_SYSTEM. +class DisconnectDistributedTpuChipsOp : public OpKernel { + public: + explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~DisconnectDistributedTpuChipsOp() override {} + + private: + // DisconnectDistributedTpuChipsOp is neither copyable nor movable. + DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) = + delete; + DisconnectDistributedTpuChipsOp& operator=( + const DisconnectDistributedTpuChipsOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_executable_info.proto b/tensorflow/core/tpu/kernels/tpu_executable_info.proto new file mode 100644 index 00000000000..359dad03a72 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_executable_info.proto @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; + +import "tensorflow/compiler/xla/service/hlo.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; + +// A serialization of TPUExecutable. Only includes fields necessary to load +// and execute a program on a worker node. 
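// Illustrative C++ consumer sketch (not part of this change) of the
// variable_indices field defined below, assuming `info` is a deserialized
// TPUExecutableInfoProto:
//
//   for (const auto& pair : info.variable_indices()) {
//     if (pair.updated()) {
//       // The resource variable passed as input `pair.index()` receives a
//       // new value from the corresponding program output.
//     }
//   }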
+message TPUExecutableInfoProto {
+  reserved 1;
+
+  // The shapes of the inputs and outputs.
+  repeated xla.ShapeProto input_shapes = 2;
+  reserved 7;  // was input_shape
+  xla.ShapeProto output_shape = 3;
+
+  message UpdateIndexPair {
+    int32 index = 1;
+    bool updated = 2;
+  }
+
+  message ShapeIndex {
+    repeated int32 index = 1;
+  }
+
+  // Dynamic output indices indicate which outputs have dynamic dimensions.
+  repeated ShapeIndex dynamic_output_indices = 11;
+
+  // For each resource variable output, what was the index of the corresponding
+  // input and was it updated? The indices are sorted by input order.
+  repeated UpdateIndexPair variable_indices = 10;
+
+  // The shapes of the outputs when represented as Tensors. These may not
+  // match the output_shape values because we may flatten tensors to avoid
+  // excess padding.
+  repeated TensorShapeProto output_tensor_shapes = 8;
+
+  reserved 4;
+
+  // Optional session module for passing XLA computations between TPUCompileOp
+  // and TPUExecuteOp. This is needed to support the
+  // --xla_dump_hlo_snapshots flag.
+  xla.HloSnapshot session_module = 5;
+
+  // The physical device ids assigned to the replicated cores.
+  xla.DeviceAssignmentProto device_assignment = 6;
+}
+
+// Metadata for a data transfer between device and host.
+message TPUHostTransferProto {
+  enum TransferDirection {
+    NONE = 0;
+    DEVICE_TO_HOST = 1;
+    HOST_TO_DEVICE = 2;
+  }
+  // Channel identifier assigned by the compiler and used in host commands.
+  int64 channel = 1;
+  // Direction of the transfer operation.
+  TransferDirection direction = 2;
+  // Channel identifier provided by the XLA client.
+  string key = 3;
+  // Depth of nested loops for this transfer operation.
+  int64 nested_while_level = 4;
+  // Shape of the data to be transferred (including layout).
+  xla.ShapeProto shape = 5;
+  // Address of the device buffer in HBM (byte offset).
+  int64 buffer_offset = 6;
+  // Original data type for this host transfer before X64 rewrite.
+  xla.PrimitiveType original_type = 7;
+  // If this host transfer is a split X64 transfer, specifies whether this
+  // transfer is for the lower bits.
+  bool is_lower_bits = 8;
+}
+
+message TPUHostTransferInfoProto {
+  repeated TPUHostTransferProto host_transfers = 1;
+}
diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
new file mode 100644
index 00000000000..cb6a82efabc
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_
+
+typedef struct XLA_TpuMeshState XLA_TpuMeshState;
+
+// Creates a new TPU mesh state object.
+XLA_TpuMeshState* TpuMeshState_Create();
+
+// Deletes the given TPU `mesh_state` object. Once deleted the object is
+// unusable.
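// Minimal lifetime sketch (illustrative only): every successful
// TpuMeshState_Create() is eventually matched by exactly one
// TpuMeshState_Free(); in this change that pairing is provided by
// TpuMeshStateInterface (see tpu_mesh_state_interface.h), e.g.
//
//   XLA_TpuMeshState* state = TpuMeshState_Create();
//   ... use state ...
//   TpuMeshState_Free(state);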
+void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); + +// Returns a pointer to an opaque mesh data structure used internally. +void* TpuMeshState_MeshCommonState(XLA_TpuMeshState* mesh_state); + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_C_API_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h new file mode 100644 index 00000000000..34202a78718 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_ +#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_ + +#include + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h" + +namespace tensorflow { + +class TpuMeshCommonState; + +namespace tpu { + +const char kTpuMeshCommonStateResourceName[] = "tpu_mesh_common_state"; + +class TpuMeshStateInterface : public tensorflow::ResourceBase { + public: + explicit TpuMeshStateInterface(XLA_TpuMeshState* handle) + : mesh_state_(handle) { + } + + ~TpuMeshStateInterface() override { + if (mesh_state_ != nullptr) { + TpuMeshState_Free(mesh_state_); + } + } + + static TpuMeshStateInterface* Create() { + return new TpuMeshStateInterface(TpuMeshState_Create()); + } + + const XLA_TpuMeshState* data() const { return mesh_state_; } + + tensorflow::TpuMeshCommonState* mesh_common_state() const { + return static_cast( + TpuMeshState_MeshCommonState(mesh_state_)); + } + + // Returns whether we should include the device assignment as a static field + // to the TPU program. This also determines whether we should include the + // device assignment as part of the compilation cache key. + bool NeedsStaticDeviceAssignment( + const TPUCompileMetadataProto& metadata, + TpuCoreTypeEnum tpu_core_type) const { + // Static device assignment enables XLA to perform certain optimization when + // all cores are used in the replicated computation. + return metadata.num_cores_per_replica() * metadata.num_replicas() == + TpuTopology_AvailableCoreCount(mesh_state_, + tpu_core_type); + } + + string DebugString() const override { return "TpuMeshStateInterface"; } + + private: + XLA_TpuMeshState* mesh_state_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_MESH_STATE_INTERFACE_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_program.cc b/tensorflow/core/tpu/kernels/tpu_program.cc new file mode 100644 index 00000000000..7d89ad15ae9 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_program.cc @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tpu/kernels/tpu_program.h" + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/stream_executor/tpu/proto_helper.h" +#include "tensorflow/stream_executor/tpu/status_helper.h" + +namespace tensorflow { +namespace tpu { + +namespace { + +namespace se_tpu = ::stream_executor::tpu; + +using stream_executor::port::StatusOr; +using xla::Shape; + +StatusOr> CompileAheadOfTime( + std::unique_ptr module_group, + const XlaCompiler::CompilationResult& compilation_result, + const TPUCompileMetadataProto& metadata, + const std::vector>& per_core_arg_shapes, + const std::vector>& per_core_output_shapes, + const std::vector>>& + per_core_variable_indices, + const absl::optional& device_assignment) { + VLOG(1) << "Run CompileAheadOfTime."; + TF_ASSIGN_OR_RETURN(TpuAotCompilationRequestProto aot_request, + CreateTpuAotCompilationRequest( + *module_group, compilation_result, metadata, + per_core_arg_shapes, per_core_output_shapes, + per_core_variable_indices, device_assignment)); + se_tpu::SerializedProto serialized_aot_request = + se_tpu::SerializeProto(aot_request); + auto cleanup = gtl::MakeCleanup([serialized_aot_request] { + se_tpu::SerializedProto_Free(serialized_aot_request); + }); + + XLA_TpuProgram** xla_tpu_programs = nullptr; + size_t count = 0; + StatusHelper status; + VLOG(1) << "Run TpuCompile_CompileAheadOfTime."; + TpuCompile_CompileAheadOfTime(serialized_aot_request, &xla_tpu_programs, + &count, status.c_status); + VLOG(1) << "Run CompileAheadOfTime completed."; + if (!status.status().ok()) { + return status.status(); + } + std::vector tpu_programs(count, nullptr); + for (size_t i = 0; i < count; ++i) { + tpu_programs[i] = xla_tpu_programs[i]; + } + delete[] xla_tpu_programs; + return tpu_programs; + return Status::OK(); +} + +StatusOr> CompileAheadOfTime( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const std::vector>& per_core_arg_shapes, + const std::vector>& per_core_output_shapes, + const std::vector>>& + per_core_variable_indices, + const absl::optional& device_assignment) { + VLOG(1) << "Compile Tpu programs."; + std::vector> hlo_modules; + auto status = CreateHloModules(metadata, compilation_result, + device_assignment, &hlo_modules); + if (!status.ok()) { + return status; + } + + return CompileAheadOfTime( + absl::make_unique(hlo_modules[0]->name(), + absl::MakeSpan(hlo_modules)), + compilation_result, metadata, per_core_arg_shapes, per_core_output_shapes, + per_core_variable_indices, device_assignment); +} + +} // namespace + 
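// Usage sketch for the class implemented below (illustrative only; `metadata`,
// `compilation_result`, `arg_core_mapping`, `per_core_arg_shapes`, and
// `xla_device_assignment` are assumed to be provided by the caller):
//
//   TpuProgram tpu_program;
//   TF_RETURN_IF_ERROR(TpuProgram::Build(metadata, compilation_result,
//                                        arg_core_mapping, per_core_arg_shapes,
//                                        xla_device_assignment, &tpu_program));
//   VLOG(1) << "Built " << tpu_program.program_count() << " TPU program(s), "
//           << tpu_program.program_size() << " bytes in total.";
//   ... execute ...
//   tpu_program.UnloadAndDestroyPrograms();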
+int64_t TpuProgram::program_size() const { + int64_t total_size = 0; + for (XLA_TpuProgram* tpu_program : tpu_programs_) { + total_size += TpuProgram_GetProgramSize(tpu_program); + } + return total_size; +} + +bool TpuProgram::LogProgramMemorySummary() { + bool success = true; + for (const XLA_TpuProgram* tpu_program : tpu_programs_) { + success &= TpuProgram_LogProgramMemorySummary(tpu_program); + } + return success; +} + +void TpuProgram::UnloadAndDestroyPrograms() { + for (XLA_TpuProgram* tpu_program : tpu_programs_) { + StatusHelper status; + TpuProgram_UnloadAndDestroy(tpu_program, status.c_status); + auto s = status.status(); + if (!s.ok()) { + LOG(ERROR) << "TpuProgram::UnloadPrograms(): " << s.ToString(); + } + } + tpu_programs_.clear(); +} + +/*static*/ Status TpuProgram::Build( + const TPUCompileMetadataProto& metadata, + const tensorflow::XlaCompiler::CompilationResult& compilation_result, + const std::vector& arg_core_mapping, + const std::vector>& per_core_arg_shapes, + const absl::optional& xla_device_assignment, + TpuProgram* tpu_program) { + std::vector> per_core_output_shapes( + metadata.num_cores_per_replica()); + TF_RETURN_IF_ERROR(ComputeOutputShapesForEachCore( + metadata, compilation_result, &per_core_output_shapes)); + + std::vector>> per_core_variable_indices( + metadata.num_cores_per_replica()); + std::vector may_modify_variables; + TF_RETURN_IF_ERROR(AddVariableUpdatesToCores( + metadata, compilation_result, arg_core_mapping, &may_modify_variables, + &per_core_output_shapes, &per_core_variable_indices)); + TF_RET_CHECK(per_core_arg_shapes.size() == metadata.num_cores_per_replica()); + TF_RET_CHECK(per_core_output_shapes.size() == per_core_arg_shapes.size()); + TF_RET_CHECK(per_core_output_shapes.size() == + per_core_variable_indices.size()); + tpu_program->set_may_modify_variables(may_modify_variables); + + // With shardable input/output pairs, XLA could generate separate + // sharding/unsharding programs along with the main program. The + // sharding/unsharding programs will be in nested entries of the AOT + // compilation result. + auto status_or = CompileAheadOfTime( + metadata, compilation_result, per_core_arg_shapes, per_core_output_shapes, + per_core_variable_indices, xla_device_assignment); + + TF_ASSIGN_OR_RETURN(std::vector xla_tpu_programs, + std::move(status_or)); + // SPMD could return 1 result for all partitions. + TF_RET_CHECK(xla_tpu_programs.size() == 1 || + xla_tpu_programs.size() == metadata.num_cores_per_replica()); + tpu_program->set_tpu_programs(xla_tpu_programs); + + // TODO(jiawenhao): Handle the case of xla_tpu_programs.size() > 1. 
+  TpuSerializedProto serialized_executable_info;
+  TpuProgram_GetExecutableInfo(xla_tpu_programs[0],
+                               &serialized_executable_info);
+  TPUExecutableInfoProto executable_info =
+      se_tpu::DeserializeProto<TPUExecutableInfoProto>(
+          serialized_executable_info);
+  tpu_program->set_executable_info(executable_info);
+  StreamExecutor_Tpu_FreeSerializedProto(&serialized_executable_info);
+
+  TPUHostTransferInfoProto host_transfer_info;
+  TpuSerializedProto serialized_host_transfer_info;
+  TpuProgram_GetHostTransferInfo(xla_tpu_programs[0],
+                                 &serialized_host_transfer_info);
+  if (serialized_host_transfer_info.size > 0) {
+    host_transfer_info = se_tpu::DeserializeProto<TPUHostTransferInfoProto>(
+        serialized_host_transfer_info);
+    StreamExecutor_Tpu_FreeSerializedProto(&serialized_host_transfer_info);
+  }
+  tpu_program->set_host_transfer_info(host_transfer_info);
+
+  TpuSerializedProto serialized_hlo_metadata;
+  TpuProgram_GetHloMetadata(xla_tpu_programs[0], &serialized_hlo_metadata);
+  xla::HloProto hlo_metadata =
+      se_tpu::DeserializeProto<xla::HloProto>(serialized_hlo_metadata);
+  tpu_program->set_hlo_metadata(hlo_metadata);
+  StreamExecutor_Tpu_FreeSerializedProto(&serialized_hlo_metadata);
+
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_program.h b/tensorflow/core/tpu/kernels/tpu_program.h
new file mode 100644
index 00000000000..aee55bd2f48
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_program.h
@@ -0,0 +1,161 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
+#define EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuAotCompilationOptions : public xla::AotCompilationOptions {
+ public:
+  explicit TpuAotCompilationOptions(int64 replica_count)
+      : num_cores_(0), replica_count_(replica_count) {}
+
+  // Returns the ID of the platform to which these options apply.
+  se::Platform::Id PlatformId() const override {
+    LOG(FATAL) << "Not implemented.";
+    return nullptr;
+  };
+
+  void set_num_cores(int64 tpu_cores) { num_cores_ = tpu_cores; }
+  int64 replica_count() const override { return replica_count_; }
+  int64 num_cores() const override { return num_cores_; }
+
+  void set_allow_separate_sharding_programs(bool allow) {
+    allow_separate_sharding_programs_ = allow;
+  }
+  bool allow_separate_sharding_programs() const {
+    return allow_separate_sharding_programs_;
+  }
+
+  const std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+  shardable_value_update_pairs() const {
+    return shardable_value_update_pairs_;
+  }
+  void set_shardable_value_update_pairs(
+      std::vector<xla::HloModuleConfig::ShardableValueUpdatePair> pairs) {
+    shardable_value_update_pairs_ = std::move(pairs);
+  }
+
+ private:
+  int64 num_cores_;
+  int64 replica_count_;
+
+  // Whether to allow the compiler to create separate sharding and unsharding
+  // programs, and modify the original program's input/output sharded size.
+  // This is used for XLA-chosen sharding on parameters without an on-device
+  // loop: the caller can invoke sharding first, then (repeatedly) invoke the
+  // sharded main program, and finally invoke the unsharding program when it
+  // needs the full output.
+  bool allow_separate_sharding_programs_ = false;
+
+  // The list of input/output pairs in the main program that could be sharded.
+  std::vector<xla::HloModuleConfig::ShardableValueUpdatePair>
+      shardable_value_update_pairs_;
+};
+
+// An executable capable of being fed to a TPU device.
+class TpuProgram {
+ public:
+  using Status = ::stream_executor::port::Status;
+
+  virtual ~TpuProgram() = default;
+
+  static Status Build(
+      const TPUCompileMetadataProto& metadata,
+      const tensorflow::XlaCompiler::CompilationResult& compilation_result,
+      const std::vector<ShardingAndIndex>& arg_core_mapping,
+      const std::vector<std::vector<xla::Shape>>& per_core_arg_shapes,
+      const absl::optional<xla::DeviceAssignment>& xla_device_assignment,
+      TpuProgram* tpu_program);
+
+  size_t program_count() const {
+    return tpu_programs_.size();
+  }
+
+  int64_t program_size() const;
+
+  bool LogProgramMemorySummary();
+
+  void UnloadAndDestroyPrograms();
+
+  const std::vector<bool>& may_modify_variables() const {
+    return may_modify_variables_;
+  }
+  void set_may_modify_variables(
+      const std::vector<bool>& may_modify_variables) {
+    may_modify_variables_ = may_modify_variables;
+  }
+
+  const tf2xla::HostComputeMetadata& host_compute_metadata() const {
+    return host_compute_metadata_;
+  }
+  void set_host_compute_metadata(
+      const tf2xla::HostComputeMetadata& host_compute_metadata) {
+    host_compute_metadata_ = host_compute_metadata;
+  }
+
+  const std::vector<XLA_TpuProgram*>& tpu_programs() const {
+    return tpu_programs_;
+  }
+  void set_tpu_programs(std::vector<XLA_TpuProgram*> tpu_programs) {
+    tpu_programs_ = tpu_programs;
+  }
+
+  const TPUExecutableInfoProto& executable_info() const {
+    return executable_info_;
+  }
+  void set_executable_info(const TPUExecutableInfoProto& executable_info) {
+    executable_info_ = executable_info;
+  }
+
+  const TPUHostTransferInfoProto& host_transfer_info() const {
+    return host_transfer_info_;
+  }
+  void set_host_transfer_info(
+      const TPUHostTransferInfoProto& host_transfer_info) {
+    host_transfer_info_ = host_transfer_info;
+  }
+
+  const xla::HloProto& hlo_metadata() const { return hlo_metadata_; }
+  void set_hlo_metadata(const xla::HloProto& hlo_metadata) {
+    hlo_metadata_ = hlo_metadata;
+  }
+
+ private:
+  std::vector<bool> may_modify_variables_;
+  tf2xla::HostComputeMetadata host_compute_metadata_;
+
+  std::vector<XLA_TpuProgram*> tpu_programs_;  // Not owned.
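+  // Metadata protos extracted from the first XLA_TpuProgram by Build().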
+  TPUExecutableInfoProto executable_info_;
+  TPUHostTransferInfoProto host_transfer_info_;
+  xla::HloProto hlo_metadata_;
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // EXPERIMENTAL_BRAIN_TPU_1VM_MINIEXECUTOR_TPU_PROGRAM_H_
diff --git a/tensorflow/core/tpu/kernels/tpu_util.cc b/tensorflow/core/tpu/kernels/tpu_util.cc
new file mode 100644
index 00000000000..5c286de7672
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_util.h"
+
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/platform/random.h"
+
+namespace tensorflow {
+namespace tpu {
+
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata) {
+  return session_metadata ? session_metadata->name() : "";
+}
+
+std::string ProtoKeyForComputation(const std::string& key, int core) {
+  return absl::StrCat(key, ":", core);
+}
+
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+    const std::string& key) {
+  const std::vector<std::string> splits = absl::StrSplit(key, '|');
+  if (splits.size() == 1) {
+    // No guaranteed_const.
+    return TpuCompilationCacheKey(key);
+  } else if (splits.size() != 3) {
+    return errors::InvalidArgument("Invalid TPU compilation cache key:", key);
+  }
+
+  TpuCompilationCacheKey parsed_key(splits.at(0));
+  parsed_key.has_guaranteed_const = true;
+  parsed_key.session_handle = splits.at(1);
+  const string fingerprint = splits.at(2);
+  parsed_key.guaranteed_const_fingerprint = [fingerprint] {
+    return fingerprint;
+  };
+  return parsed_key;
+}
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+    const XlaCompiler::CompilationResult& compilation_result) {
+  xla::CompileOnlyClient::AotXlaComputationInstance instance;
+  instance.computation = compilation_result.computation.get();
+  for (const xla::Shape& shape : compilation_result.xla_input_shapes) {
+    instance.argument_layouts.push_back(&shape);
+  }
+  instance.result_layout = &compilation_result.xla_output_shape;
+  return instance;
+}
+
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape) {
+  if (tensor.dtype() != DT_INT64 ||
+      !TensorShapeUtils::IsVector(tensor.shape())) {
+    return errors::InvalidArgument("Shape tensor must be an int64 vector.");
+  }
+  const int64 rank = tensor.NumElements();
+  auto tensor_dims = tensor.flat<int64>();
+  std::vector<int64> dims(rank);
+  for (int64 i = 0; i < rank; ++i) {
+    dims[i] = tensor_dims(i);
+  }
+  return TensorShapeUtils::MakeShape(dims, shape);
+}
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes) {
+  shapes->resize(dynamic_shapes.size());
+  for (int i = 0; i < dynamic_shapes.size(); ++i) {
+    TF_RETURN_IF_ERROR(
+        ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
+  }
+  return Status::OK();
+}
+
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes) {
+  shapes->resize(dynamic_shapes.end() - dynamic_shapes.begin());
+  size_t i = 0;
+  for (auto& dynamic_shape : dynamic_shapes) {
+    TF_RETURN_IF_ERROR(
+        ShapeTensorToTensorShape(dynamic_shape.tensor(), &(*shapes)[i]));
+    ++i;
+  }
+  return Status::OK();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_util.h b/tensorflow/core/tpu/kernels/tpu_util.h
new file mode 100644
index 00000000000..0ca94d0af59
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_util.h
@@ -0,0 +1,67 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/compile_only_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may
+// be null.
+std::string SessionNameFromMetadata(const SessionMetadata* session_metadata);
+
+// Generates the cache proto key for a given computation on a TPU core.
+std::string ProtoKeyForComputation(const std::string& key, int core);
+
+// Returns a TpuCompilationCacheKey parsed from the given key, or an error.
+xla::StatusOr<TpuCompilationCacheKey> ParseCompilationCacheKey(
+    const std::string& key);
+
+xla::CompileOnlyClient::AotXlaComputationInstance
+BuildAotXlaComputationInstance(
+    const XlaCompiler::CompilationResult& compilation_result);
+
+// Returns true if TPU compilation is enabled.
+bool IsTpuCompilationEnabled();
+
+// Converts an int64 host memory `tensor` to a `shape`.
+Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape);
+
+Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes);
+Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
+                                   std::vector<TensorShape>* shapes);
+
+// Given a tensor of `shape` and `type`, what shape should it be stored as on
+// the TPU device? This function transposes or flattens excessively-padded
+// tensors to rank 1, but leaves other tensor shapes alone.
+xla::StatusOr<xla::Shape> TpuShapeRepresentation(const TensorShape& shape,
+                                                 DataType type,
+                                                 bool use_fast_memory);
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_
diff --git a/tensorflow/core/tpu/kernels/trace_util.h b/tensorflow/core/tpu/kernels/trace_util.h
new file mode 100644
index 00000000000..4e0b7c96892
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/trace_util.h
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ + +#ifdef PLATFORM_GOOGLE +#include "base/tracer.h" +#else +#undef TRACESTRING +#define TRACESTRING(x) +#undef TRACELITERAL +#define TRACELITERAL(x) +#endif + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ diff --git a/tensorflow/core/tpu/tpu_library_loader.cc b/tensorflow/core/tpu/tpu_library_loader.cc index 3bc835c9c7f..c89de142a9f 100644 --- a/tensorflow/core/tpu/tpu_library_loader.cc +++ b/tensorflow/core/tpu/tpu_library_loader.cc @@ -64,13 +64,20 @@ TfTpu_ConfigApiFn* ConfigApiFn() { } Status InitializeTpuLibrary(void* library_handle) { + bool shared_object_loaded = true; if (library_handle == nullptr) { library_handle = dlopen(nullptr, RTLD_LAZY); + shared_object_loaded = false; } TF_RETURN_IF_ERROR(SetTpuInitializeStructFns(library_handle)); TF_RETURN_IF_ERROR(SetTpuConfigStructFns(library_handle)); + if (shared_object_loaded) { + // Initialize TPU platform when the platform code is loaded from a library. + InitializeApiFn()->TfTpu_InitializeFn(); + } + return Status::OK(); } diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD new file mode 100644 index 00000000000..52ca40f8d3f --- /dev/null +++ b/tensorflow/stream_executor/tpu/BUILD @@ -0,0 +1,234 @@ +# Description: StreamExecutor Interface for TPUs + +package( + default_visibility = ["//tensorflow/core/tpu:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tpu_executor_c_api_hdrs", + hdrs = ["tpu_executor_c_api.h"], + deps = [ + "//tensorflow/c:tf_attrtype", + "//tensorflow/c:tf_datatype", + "//tensorflow/c:tf_status", + ], +) + +cc_library( + name = "tpu_node_context_c_api_hdrs", + hdrs = ["tpu_node_context_c_api.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + ], +) + +cc_library( + name = "status_helper", + hdrs = ["status_helper.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + "//tensorflow/core/platform:status", + ], +) + +cc_library( + name = "c_api_conversions", + hdrs = ["c_api_conversions.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/stream_executor:stream", + "@com_google_absl//absl/container:inlined_vector", + ], +) + +cc_library( + name = "proto_helper", + srcs = ["proto_helper.cc"], + hdrs = ["proto_helper.h"], + deps = ["//tensorflow/core:lib"], +) + +cc_library( + name = "tpu_stream", + hdrs = ["tpu_stream.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + ], +) + +cc_library( + name = "tpu_timer", + hdrs = ["tpu_timer.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor:stream", + ], +) + +cc_library( + 
name = "tpu_executor", + srcs = ["tpu_executor.cc"], + hdrs = ["tpu_executor.h"], + deps = [ + ":c_api_conversions", + ":status_helper", + ":tpu_executor_c_api_hdrs", + ":tpu_executor_interface", + ":tpu_platform", + ":tpu_platform_interface", + ":tpu_stream", + ":tpu_timer", + "//tensorflow/c:tf_status", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "tpu_executor_hdrs", + hdrs = ["tpu_executor.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + ":tpu_executor_interface", + ":tpu_platform_hdrs", + ":tpu_platform_interface", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor:stream_header", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "tpu_platform_hdrs", + hdrs = ["tpu_platform.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + ":tpu_platform_interface", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor:stream_header", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +cc_library( + name = "tpu_node_context", + srcs = ["tpu_node_context.cc"], + hdrs = ["tpu_node_context.h"], + deps = [ + ":status_helper", + ":tpu_executor_c_api_hdrs", + ":tpu_node_context_c_api_hdrs", + ":tpu_platform_interface", + ":tpu_transfer_manager", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/core:framework", + "//tensorflow/stream_executor:device_memory_allocator", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "tpu_platform", + srcs = ["tpu_platform.cc"], + hdrs = ["tpu_platform.h"], + deps = [ + ":status_helper", + ":tpu_executor_c_api_hdrs", + ":tpu_executor_hdrs", + ":tpu_platform_interface", + "//tensorflow/c:tf_status", + "//tensorflow/core/platform:types", + "//tensorflow/stream_executor:stream", + "@com_google_absl//absl/container:flat_hash_map", + ], + alwayslink = True, +) + +cc_library( + name = "tpu_transfer_manager", + srcs = ["tpu_transfer_manager_registration.cc"], + deps = [ + ":tpu_platform", + ":tpu_transfer_manager_base", + "//tensorflow/compiler/xla/service:transfer_manager", + ], +) + +cc_library( + name = "tpu_transfer_manager_base", + srcs = ["tpu_transfer_manager.cc"], + hdrs = ["tpu_transfer_manager.h"], + deps = [ + ":c_api_conversions", + ":proto_helper", + ":status_helper", + ":tpu_executor_c_api_hdrs", + ":tpu_platform", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/stream_executor:stream", + ], +) + +cc_library( + name = "tpu_computation_placer", + srcs = ["tpu_computation_placer.cc"], + hdrs = ["tpu_computation_placer.h"], + deps = [ + ":tpu_executor_c_api_hdrs", + ":tpu_platform", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:computation_placer", + ], + alwayslink = True, +) + +cc_library( + name = "tpu_platform_interface", + srcs = ["tpu_platform_interface.cc"], + hdrs = ["tpu_platform_interface.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform:types", + 
"//tensorflow/stream_executor:multi_platform_manager", + "//tensorflow/stream_executor:stream_executor_headers", + ], +) + +cc_library( + name = "tpu_stream_interface", + hdrs = ["tpu_stream_interface.h"], + visibility = ["//visibility:public"], + deps = ["//tensorflow/stream_executor:stream_executor_internal"], +) + +cc_library( + name = "tpu_executor_interface", + hdrs = ["tpu_executor_interface.h"], + visibility = ["//visibility:public"], + deps = [ + ":tpu_platform_interface", + "//tensorflow/core/platform:errors", + "//tensorflow/stream_executor:stream_executor_internal", + "//tensorflow/stream_executor:stream_header", + ], +) diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.h b/tensorflow/stream_executor/tpu/c_api_conversions.h new file mode 100644 index 00000000000..1bb9ecee688 --- /dev/null +++ b/tensorflow/stream_executor/tpu/c_api_conversions.h @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +class TpuConversions { + public: + static stream_executor::DeviceMemoryBase + SE_DeviceMemoryBaseToDeviceMemoryBase(SE_DeviceMemoryBase se_base) { + stream_executor::DeviceMemoryBase base(se_base.opaque, se_base.size); + base.SetPayload(se_base.payload); + return base; + } + + static SE_DeviceMemoryBase DeviceMemoryBaseToSE_DeviceMemoryBase( + const stream_executor::DeviceMemoryBase& base) { + SE_DeviceMemoryBase se_base; + se_base.opaque = const_cast(base.opaque()); + se_base.payload = base.payload(); + se_base.size = base.size(); + return se_base; + } + + static xla::Shape CShapeToXlaShape(XLA_Shape* shape) { + xla::ShapeProto p; + p.ParseFromArray(shape->bytes, shape->size); + return xla::Shape(p); + } + + static void XlaShapeToCShape(const xla::Shape& xla_shape, + XLA_Shape* c_shape) { + xla::ShapeProto p = xla_shape.ToProto(); + std::string p_str = p.SerializeAsString(); + c_shape->bytes = new char[p_str.size()]; + c_shape->size = p_str.size(); + memcpy(c_shape->bytes, p_str.data(), p_str.size()); + } + + static void XLAShapedBufferToCShapedBuffer( + const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer) { + XlaShapeToCShape(buffer.on_host_shape(), &c_device_buffer->on_host_shape); + XlaShapeToCShape(buffer.on_device_shape(), + &c_device_buffer->on_device_shape); + c_device_buffer->device_ordinal = buffer.device_ordinal(); + absl::InlinedVector bases; + for (auto& pair : buffer.buffers()) { + 
bases.push_back(DeviceMemoryBaseToSE_DeviceMemoryBase(pair.second)); + } + c_device_buffer->count = bases.size(); + c_device_buffer->bases = new SE_DeviceMemoryBase[bases.size()]; + for (int i = 0; i < bases.size(); ++i) { + c_device_buffer->bases[i] = bases[i]; + } + } + + static void XLALiteralToCLiteral(const xla::LiteralSlice& literal, + XLA_Literal* c_literal) { + XlaShapeToCShape(literal.shape(), &c_literal->shape); + auto shapes = xla::ShapeUtil::GetLeafShapes(literal.shape()); + c_literal->buffers = new char*[shapes.size()]; + c_literal->sizes = new size_t[shapes.size()]; + c_literal->count = shapes.size(); + for (int i = 0; i < shapes.size(); ++i) { + c_literal->buffers[i] = reinterpret_cast( + const_cast(literal.untyped_data(shapes[i].index))); + c_literal->sizes[i] = literal.size_bytes(shapes[i].index); + } + } + + static xla::MutableBorrowingLiteral CLiteralToXLALiteral( + XLA_Literal* c_literal) { + xla::Shape shape = CShapeToXlaShape(&c_literal->shape); + LOG(INFO) << "Shape: " << shape.DebugString(); + return xla::MutableBorrowingLiteral( + absl::MakeSpan(c_literal->buffers, c_literal->count), shape); + } + + static void CShapeCleanup(XLA_Shape* c_shape) { delete[] c_shape->bytes; } + + static void CLiteralCleanup(XLA_Literal* c_literal) { + delete[] c_literal->buffers; + delete[] c_literal->sizes; + CShapeCleanup(&c_literal->shape); + } + + static void CShapedBufferCleanup(XLA_ShapedBuffer* c_buffer) { + CShapeCleanup(&c_buffer->on_device_shape); + CShapeCleanup(&c_buffer->on_host_shape); + delete[] c_buffer->bases; + } +}; + +#endif // THIRD_PARTY_TENSORFLOW_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_ diff --git a/tensorflow/stream_executor/tpu/proto_helper.cc b/tensorflow/stream_executor/tpu/proto_helper.cc new file mode 100644 index 00000000000..db663c6671b --- /dev/null +++ b/tensorflow/stream_executor/tpu/proto_helper.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/tpu/proto_helper.h" + +extern "C" { + +void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto) { + CHECK_NE(proto, nullptr); + CHECK_NE(proto->bytes, nullptr); + CHECK_GT(proto->size, 0); + delete[] proto->bytes; +} + +} // extern "C" diff --git a/tensorflow/stream_executor/tpu/proto_helper.h b/tensorflow/stream_executor/tpu/proto_helper.h new file mode 100644 index 00000000000..3bd2b09f95e --- /dev/null +++ b/tensorflow/stream_executor/tpu/proto_helper.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_ + +#include + +#include "tensorflow/core/platform/logging.h" + +extern "C" { + +typedef struct TpuSerializedProto { + const char* bytes; + size_t size; +} TpuSerializedProto; + +void StreamExecutor_Tpu_FreeSerializedProto(const TpuSerializedProto* proto); + +} // extern "C" + +namespace stream_executor { +namespace tpu { + +using SerializedProto = TpuSerializedProto; + +// Serializes a proto and put the result in the given SerializedProto* argument. +// +// Users should call SerializedProto_Free on `serialized_proto` afterwards. +template +inline void SerializeProto(const Proto& proto, + SerializedProto* serialized_proto) { + auto size = proto.ByteSizeLong(); + auto bytes = new char[size]; + CHECK(proto.SerializeToArray(bytes, size)); + serialized_proto->size = size; + serialized_proto->bytes = bytes; +} + +// Serializes a proto and return the result as a SerializedProto value. +// +// Users should call SerializedProto_Free on the return value afterwards. +template +inline SerializedProto SerializeProto(const Proto& proto) { + SerializedProto serialized_proto; + SerializeProto(proto, &serialized_proto); + return serialized_proto; +} + +// Deserializes a buffer and return the corresponding proto. If the buffer is +// empty, return an empty proto. +template +inline Proto DeserializeProto(const SerializedProto& serialized_proto) { + Proto proto; + if (serialized_proto.bytes != nullptr) { + CHECK_GT(serialized_proto.size, 0); + CHECK(proto.ParseFromArray(serialized_proto.bytes, serialized_proto.size)) + << "Invalid buffer, failed to deserialize buffer."; + } + return proto; +} + +// Releases the memory allocated for serialized protos. +inline void SerializedProto_Free(const SerializedProto& serialized_proto) { + CHECK_NE(serialized_proto.bytes, nullptr); + CHECK_GT(serialized_proto.size, 0); + delete[] serialized_proto.bytes; +} + +} // namespace tpu +} // namespace stream_executor + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_PROTO_HELPER_H_ diff --git a/tensorflow/stream_executor/tpu/status_helper.h b/tensorflow/stream_executor/tpu/status_helper.h new file mode 100644 index 00000000000..8fcf302edac --- /dev/null +++ b/tensorflow/stream_executor/tpu/status_helper.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_ + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +struct StatusHelper { + StatusHelper() : c_status(TpuStatus_New()) {} + ~StatusHelper() { TpuStatus_Free(c_status); } + bool ok() { return TpuStatus_Code(c_status) == 0; } + tensorflow::Status status() { + if (!ok()) { + return tensorflow::Status( + tensorflow::error::Code(TpuStatus_Code(c_status)), + TpuStatus_Message(c_status)); + } else { + return tensorflow::Status::OK(); + } + } + SE_Status* c_status; +}; + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_STATUS_HELPER_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.cc b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc new file mode 100644 index 00000000000..660b446d953 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/tpu/tpu_computation_placer.h" + +#include "tensorflow/stream_executor/tpu/tpu_platform.h" + +template +using StatusOr = TpuComputationPlacer::StatusOr; + +TpuComputationPlacer::TpuComputationPlacer() { + placer_ = TpuComputationPlacer_New(); +} + +TpuComputationPlacer::~TpuComputationPlacer() { + TpuComputationPlacer_Free(placer_); +} + +StatusOr TpuComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { + LOG(FATAL) << "Unimplemented."; +} + +StatusOr TpuComputationPlacer::AssignDevices( + int replica_count, int computation_count) { + LOG(FATAL) << "Unimplemented."; +} + +static std::unique_ptr CreateTpuComputationPlacer() { + return std::make_unique(); +} + +static bool InitModule() { + xla::ComputationPlacer::RegisterComputationPlacer( + tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/stream_executor/tpu/tpu_computation_placer.h b/tensorflow/stream_executor/tpu/tpu_computation_placer.h new file mode 100644 index 00000000000..c8f4c9e3888 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_computation_placer.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_ + +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +class TpuComputationPlacer : public xla::ComputationPlacer { + public: + template + using StatusOr = xla::StatusOr; + + TpuComputationPlacer(); + ~TpuComputationPlacer() override; + + StatusOr DeviceId(int replica, int computation, int replica_count, + int computation_count) override; + + StatusOr AssignDevices(int replica_count, + int computation_count) override; + + private: + XLA_ComputationPlacer* placer_; +}; + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_COMPUTATION_PLACER_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_executor.cc b/tensorflow/stream_executor/tpu/tpu_executor.cc new file mode 100644 index 00000000000..92808936467 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_executor.cc @@ -0,0 +1,355 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/stream_executor/tpu/tpu_executor.h" + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/tpu/c_api_conversions.h" +#include "tensorflow/stream_executor/tpu/status_helper.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_stream.h" +#include "tensorflow/stream_executor/tpu/tpu_timer.h" + +using stream_executor::DeviceMemoryBase; + +namespace tensorflow { + +namespace { +using ::stream_executor::port::Status; +} // namespace + +TpuExecutor::~TpuExecutor() { TpuExecutor_Free(executor_); } + +Status TpuExecutor::Init(int device_ordinal, + ::stream_executor::DeviceOptions device_options) { + StatusHelper status; + SE_DeviceOptions* options = + TpuExecutor_NewDeviceOptions(device_options.flags()); + TpuExecutor_Init(executor_, device_ordinal, options, status.c_status); + TpuExecutor_FreeDeviceOptions(options); + return status.status(); +} + +int TpuExecutor::PlatformDeviceCount() { + return TpuExecutor_PlatformDeviceCount(executor_); +} + +void TpuExecutor::SyncAndForgetFailedStreams() { + TpuExecutor_SyncAndForgetFailedStreams(executor_); +} + +bool TpuExecutor::SynchronizeAllActivity() { + return TpuExecutor_SynchronizeAllActivity(executor_); +} + +Status TpuExecutor::BlockHostUntilDone(Stream* stream) { + StatusHelper status; + TpuExecutor_BlockHostUntilDone( + executor_, stream_map().at(stream->implementation()), status.c_status); + return status.status(); +} + +Status TpuExecutor::BlockUntilDoneOrFailed() { + StatusHelper status; + TpuExecutor_BlockUntilDoneOrFailed(executor_, status.c_status); + return status.status(); +} + +Status TpuExecutor::GetStatus(Stream* stream) { + StatusHelper status; + TpuExecutor_GetStatus(executor_, stream_map().at(stream->implementation()), + status.c_status); + return status.status(); +} + +bool TpuExecutor::AllocateStream(Stream* stream) { + return TpuExecutor_AllocateStream(executor_, + stream_map().at(stream->implementation())); +} + +void TpuExecutor::DeallocateStream(Stream* stream) { + TpuExecutor_DeallocateStream(executor_, + stream_map().at(stream->implementation())); + stream_map().erase(stream->implementation()); +} + +bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) { + return TpuExecutor_CreateStreamDependency( + executor_, stream_map().at(dependent->implementation()), + stream_map().at(other->implementation())); +} + +Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); } + +Status TpuExecutor::DeallocateEvent(Event* event) { return Status::OK(); } + +// AllocateTimer/DeallocateTimer have no specialization. 
+bool TpuExecutor::AllocateTimer(Timer* timer) { return true; } + +void TpuExecutor::DeallocateTimer(Timer* timer) {} + +bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) { + return TpuExecutor_StartTimer(executor_, + stream_map().at(stream->implementation()), + timer_map_.at(timer->implementation())); +} + +bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) { + return TpuExecutor_StopTimer(executor_, + stream_map().at(stream->implementation()), + timer_map_.at(timer->implementation())); +} + +stream_executor::Event::Status TpuExecutor::PollForEventStatus( + stream_executor::Event* event) { + return stream_executor::Event::Status(TpuExecutor_PollForEventStatus( + executor_, event_map_.at(event->implementation()))); +} + +Status TpuExecutor::RecordEvent(Stream* stream, + ::stream_executor::Event* event) { + StatusHelper status; + TpuExecutor_RecordEvent(executor_, stream_map().at(stream->implementation()), + event_map_.at(event->implementation()), + status.c_status); + return status.status(); +} + +Status TpuExecutor::WaitForEvent(Stream* stream, + ::stream_executor::Event* event) { + StatusHelper status; + TpuExecutor_WaitForEvent(executor_, stream_map().at(stream->implementation()), + event_map_.at(event->implementation()), + status.c_status); + return status.status(); +} + +// Implementations for Timer, Stream, Event +// We need to map these implementations to internal equivalents -- thus we +// allocate the internal Timer, Stream and Event operations here, and map +// the implementations to the internal values. The "wrapper" interfaces are +// responsible for deallocating the internal value when they are destroyed. + +// Called by Timer::Timer +std::unique_ptr<::stream_executor::internal::TimerInterface> +TpuExecutor::GetTimerImplementation() { + SE_Timer* tpu_timer = TpuTimer_New(executor_); + auto ptr = absl::make_unique(tpu_timer); + timer_map_[ptr.get()] = tpu_timer; + return ptr; +} + +// Called by Stream::Stream +std::unique_ptr<::stream_executor::internal::StreamInterface> +TpuExecutor::GetStreamImplementation() { + SE_Stream* tpu_stream = TpuStream_New(executor_); + auto ptr = absl::make_unique(tpu_stream); + stream_map()[ptr.get()] = tpu_stream; + return ptr; +} + +// Called by Event::Event +std::unique_ptr<::stream_executor::internal::EventInterface> +TpuExecutor::CreateEventImplementation() { + SE_Event* tpu_event = TpuEvent_New(executor_); + auto ptr = absl::make_unique(tpu_event); + event_map_[ptr.get()] = tpu_event; + return ptr; +} + +DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) { + SE_DeviceMemoryBase se_base = + TpuExecutor_Allocate(executor_, size, memory_space); + return TpuConversions::SE_DeviceMemoryBaseToDeviceMemoryBase(se_base); +} + +void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) { + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(memory); + TpuExecutor_Deallocate(executor_, &se_base); +} + +void TpuExecutor::Deallocate(DeviceMemoryBase* memory) { + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*memory); + TpuExecutor_Deallocate(executor_, &se_base); +} + +bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const { + int64_t _free; + int64_t _total; + if (TpuExecutor_DeviceMemoryUsage(executor_, &_free, &_total)) { + *free = _free; + *total = _total; + return true; + } + return false; +} + +absl::optional +TpuExecutor::GetAllocatorStats() { + SE_AllocatorStats c_stats; + if 
(TpuExecutor_GetAllocatorStats(executor_, &c_stats)) { + ::stream_executor::AllocatorStats stats; + stats.num_allocs = c_stats.num_allocs; + stats.bytes_in_use = c_stats.bytes_in_use; + stats.peak_bytes_in_use = c_stats.peak_bytes_in_use; + stats.largest_alloc_size = c_stats.largest_alloc_size; + if (c_stats.has_bytes_limit) { + stats.bytes_limit = c_stats.bytes_limit; + } + stats.bytes_reserved = c_stats.bytes_reserved; + stats.peak_bytes_reserved = c_stats.peak_bytes_reserved; + if (c_stats.has_bytes_reservable_limit) { + stats.bytes_reservable_limit = c_stats.bytes_reservable_limit; + } + stats.largest_free_block_bytes = c_stats.largest_free_block_bytes; + return stats; + } + return {}; +} + +Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) { + StatusHelper status; + TpuExecutor_WaitForInfeedReady(executor_, infeed_queue_index, + status.c_status); + return status.status(); +} + +Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) { + StatusHelper status; + TpuExecutor_WaitForOutfeedReady(executor_, outfeed_queue_index, + status.c_status); + return status.status(); +} + +void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index, + absl::Span bytes, StatusCallback done) { + StatusHelper status; + TpuExecutor_DequeueOutfeed(executor_, outfeed_queue_index, bytes.data(), + bytes.size(), status.c_status); + done(status.status()); +} + +Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index, + absl::Span bytes) { + StatusHelper status; + TpuExecutor_EnqueueInfeed(executor_, infeed_queue_index, bytes.data(), + bytes.size(), status.c_status); + return status.status(); +} + +bool TpuExecutor::Memcpy(Stream* stream, void* host_dst, + const ::stream_executor::DeviceMemoryBase& device_src, + uint64 size) { + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src); + return TpuExecutor_MemcpyToHost(executor_, + stream_map().at(stream->implementation()), + host_dst, &se_base, size); +} + +bool TpuExecutor::Memcpy(Stream* stream, + ::stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64 size) { + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst); + return TpuExecutor_MemcpyFromHost(executor_, + stream_map().at(stream->implementation()), + &se_base, host_src, size); +} + +Status TpuExecutor::SynchronousMemcpy( + ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src, + uint64 size) { + StatusHelper status; + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(*device_dst); + TpuExecutor_SynchronousMemcpyFromHost(executor_, &se_base, host_src, size, + status.c_status); + return status.status(); +} + +Status TpuExecutor::SynchronousMemcpy( + void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src, + uint64 size) { + StatusHelper status; + SE_DeviceMemoryBase se_base = + TpuConversions::DeviceMemoryBaseToSE_DeviceMemoryBase(device_src); + TpuExecutor_SynchronousMemcpyToHost(executor_, host_dst, &se_base, size, + status.c_status); + return status.status(); +} + +Status TpuExecutor::SynchronousMemcpyDeviceToDevice( + ::stream_executor::DeviceMemoryBase* device_dst, + const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) { + return ::stream_executor::port::UnimplementedError( + "This operation not supported on TPU"); +} + +bool TpuExecutor::MemcpyDeviceToDevice( + Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst, + const ::stream_executor::DeviceMemoryBase& host_src, uint64 
size) { + LOG(FATAL) << __func__ << " not supported on TpuExecutor"; +} + +struct HostCallbackContext { + std::function callback; +}; + +SE_Status* HostCallbackTrampoline(void* ctx) { + HostCallbackContext* host_ctx = reinterpret_cast(ctx); + Status status = host_ctx->callback(); + SE_Status* c_status = + TpuStatus_Create(status.code(), status.error_message().c_str()); + delete host_ctx; + return c_status; +} + +bool TpuExecutor::HostCallback(Stream* stream, + std::function callback) { + HostCallbackContext* ctx = new HostCallbackContext{callback}; + return TpuExecutor_HostCallback(executor_, + stream_map().at(stream->implementation()), + &HostCallbackTrampoline, ctx); +} + +TpuExecutor::StatusOr> +TpuExecutor::CreateDeviceDescription() const { + StatusHelper status; + SE_DeviceDescription* description = TpuDeviceDescription_New(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [description]() { TpuDeviceDescription_Free(description); }); + TpuExecutor_CreateDeviceDescription(executor_, description, status.c_status); + if (status.status().ok()) { + stream_executor::internal::DeviceDescriptionBuilder builder; + CHECK_NE(description->device_vendor, nullptr); + builder.set_device_vendor(description->device_vendor); + builder.set_name(description->name); + builder.set_clock_rate_ghz(description->clock_rate_ghz); + builder.set_core_count(description->core_count); + builder.set_ecc_enabled(description->ecc_enabled); + builder.set_device_memory_size(description->device_memory_size); + builder.set_platform_version(description->platform_version); + return builder.Build(); + } + return status.status(); +} + +} // namespace tensorflow diff --git a/tensorflow/stream_executor/tpu/tpu_executor.h b/tensorflow/stream_executor/tpu/tpu_executor.h new file mode 100644 index 00000000000..5f366421c4c --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_executor.h @@ -0,0 +1,241 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_options.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/temporary_device_memory.h" +#include "tensorflow/stream_executor/timer.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" + +namespace tensorflow { + +class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { + public: + using Status = ::stream_executor::port::Status; + template + using StatusOr = ::stream_executor::port::StatusOr; + using StatusCallback = std::function; + using Stream = ::stream_executor::Stream; + using Event = ::stream_executor::Event; + using Timer = ::stream_executor::Timer; + using DeviceMemoryBase = ::stream_executor::DeviceMemoryBase; + using StreamInterface = ::stream_executor::internal::StreamInterface; + using StreamExecutorInterface = + ::stream_executor::internal::StreamExecutorInterface; + + using EventMap = + absl::flat_hash_map; + using TimerMap = + absl::flat_hash_map; + + explicit TpuExecutor(::tensorflow::tpu::TpuPlatformInterface* platform, + SE_StreamExecutor* executor) + : platform_(platform), executor_(executor) {} + + ~TpuExecutor() override; + + Status Init(int device_ordinal, + ::stream_executor::DeviceOptions device_options) override; + + DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override; + + StatusOr AllocateDeviceMemoryBase(uint64 size, + int64 memory_space); + + Status AllocateEvent(Event* event) override; + + bool AllocateStream(Stream* stream) override; + + bool AllocateTimer(Timer* timer) override; + + Status BlockHostUntilDone(::stream_executor::Stream* stream) override; + + Status BlockUntilDoneOrFailed(); + + StatusOr> + CreateDeviceDescription() const override; + + bool CreateStreamDependency(Stream* dependent, Stream* other) override; + + void DeallocateStream(Stream* stream) override; + + void Deallocate(const DeviceMemoryBase& memory); + + void Deallocate(DeviceMemoryBase* memory) override; + + Status DeallocateEvent(Event* event) override; + + void DeallocateTimer(Timer* timer) override; + + bool DeviceMemoryUsage(int64* free, int64* total) const override; + + void DequeueOutfeed(int32 outfeed_queue_index, absl::Span bytes, + StatusCallback done); + + Status EnqueueInfeed(int32 infeed_queue_index, + absl::Span bytes); + + absl::optional GetAllocatorStats() override; + + Status GetStatus(Stream* stream) override; + + std::unique_ptr<::stream_executor::internal::StreamInterface> + GetStreamImplementation() override; + + std::unique_ptr<::stream_executor::internal::TimerInterface> + GetTimerImplementation() override; + + std::unique_ptr<::stream_executor::internal::EventInterface> + CreateEventImplementation() override; + + bool HostCallback(Stream* stream, std::function callback) override; + + bool 
Memcpy(Stream* stream, void* host_dst, + const ::stream_executor::DeviceMemoryBase& device_src, + uint64 size) override; + + bool Memcpy(Stream* stream, ::stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64 size) override; + + bool MemcpyDeviceToDevice(Stream* stream, + ::stream_executor::DeviceMemoryBase* gpu_dst, + const ::stream_executor::DeviceMemoryBase& host_src, + uint64 size) override; + + void SyncAndForgetFailedStreams(); + bool SynchronizeAllActivity() override; + + Status SynchronousMemcpy(::stream_executor::DeviceMemoryBase* device_dst, + const void* host_src, uint64 size) override; + Status SynchronousMemcpy( + void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src, + uint64 size) override; + Status SynchronousMemcpyDeviceToDevice( + ::stream_executor::DeviceMemoryBase* device_dst, + const ::stream_executor::DeviceMemoryBase& device_src, + uint64 size) override; + + int PlatformDeviceCount() override; + + Event::Status PollForEventStatus(Event* event) override; + Status RecordEvent(Stream* stream, ::stream_executor::Event* event) override; + Status WaitForEvent(Stream* stream, ::stream_executor::Event* event) override; + + bool StartTimer(Stream* stream, ::stream_executor::Timer* timer) override; + bool StopTimer(Stream* stream, ::stream_executor::Timer* timer) override; + + Status WaitForInfeedReady(int32 infeed_queue_index); + + Status WaitForOutfeedReady(int32 outfeed_queue_index); + + const ::tensorflow::tpu::TpuPlatformInterface& platform() const override { + return *platform_; + } + + ::tensorflow::tpu::TpuPlatformInterface& platform() override { + return *platform_; + } + + // TODO(henrytan): convert this to override once the base interface is changed + // to TpuExecutorInterface. + StatusOr> + CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset, + int64 size) override { + LOG(FATAL) << "Unimplemented."; + } + + // -- Unimplemented (stubbed out) methods. 
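+  //
+  // These overrides are required by StreamExecutorInterface but have no TPU
+  // C API counterpart wired up yet; they abort with LOG(FATAL) if called.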
+ std::unique_ptr + CreateKernelImplementation() override { + LOG(FATAL) << "Not yet implemented"; + } + + stream_executor::SharedMemoryConfig GetDeviceSharedMemoryConfig() override { + LOG(FATAL) << "not yet implemented"; + } + + void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, + uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + Status Memset32(Stream* stream, DeviceMemoryBase* location, uint32 pattern, + uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + Status EnablePeerAccessTo(StreamExecutorInterface* other) override { + LOG(FATAL) << "not yet implemented"; + } + bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { + LOG(FATAL) << "not yet implemented"; + } + Status SetDeviceSharedMemoryConfig( + stream_executor::SharedMemoryConfig config) override { + LOG(FATAL) << "not yet implemented"; + } + void* HostMemoryAllocate(uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + void HostMemoryDeallocate(void* mem) override { + LOG(FATAL) << "not yet implemented"; + } + bool HostMemoryRegister(void* mem, uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + bool HostMemoryUnregister(void* mem) override { + LOG(FATAL) << "not yet implemented"; + } + Status SynchronousMemZero(DeviceMemoryBase* location, uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64 size) override { + LOG(FATAL) << "not yet implemented"; + } + + private: + EventMap event_map_; + TimerMap timer_map_; + + TpuPlatform::StreamMap& stream_map() { + return *(static_cast(platform_)->stream_map()); + } + + ::tensorflow::tpu::TpuPlatformInterface* platform_; + SE_StreamExecutor* executor_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_executor_c_api.h b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h new file mode 100644 index 00000000000..8bf2ecbc938 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_executor_c_api.h @@ -0,0 +1,293 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_ + +#include +#include + +#include "tensorflow/c/tf_attrtype.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +typedef struct SE_Platform SE_Platform; +typedef struct SE_StreamExecutor SE_StreamExecutor; +typedef struct SE_Stream SE_Stream; +typedef struct SE_Event SE_Event; +typedef struct SE_Timer SE_Timer; +typedef struct SE_Status SE_Status; + +typedef struct SE_PlatformId { + void* id; // aka stream_executor::Platform::Id +} SE_PlatformId; +typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig; +typedef struct SE_DeviceOptions SE_DeviceOptions; +typedef SE_Status* (*SE_StatusCallbackFn)(void*); + +typedef struct SE_DeviceMemoryBase { + void* opaque; + uint64_t size; + uint64_t payload; +} SE_DeviceMemoryBase; + +typedef struct SE_AllocatorStats { + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + + bool has_bytes_limit; + int64_t bytes_limit; + + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + + bool has_bytes_reservable_limit; + int64_t bytes_reservable_limit; + + int64_t largest_free_block_bytes; +} SE_AllocatorStats; + +typedef struct SE_DeviceDescription { + char* device_vendor; + char* platform_version; + char* driver_version; + char* runtime_version; + char* pci_bus_id; + char* name; + + int64_t thread_dim_limit_x; + int64_t thread_dim_limit_y; + int64_t thread_dim_limit_z; + int64_t block_dim_limit_x; + int64_t block_dim_limit_y; + int64_t block_dim_limit_z; + + int64_t threads_per_core_limit; + int64_t threads_per_block_limit; + int64_t threads_per_warp; + + int64_t registers_per_core_limit; + int64_t registers_per_block_limit; + + int64_t device_address_bits; + int64_t device_memory_size; + int64_t memory_bandwidth; + + int64_t shared_memory_per_core; + int64_t shared_memory_per_block; + + float clock_rate_ghz; + + int cuda_compute_capability_major; + int cuda_compute_capability_minor; + + int rocm_amdgpu_isa_version; + + int numa_node; + int core_count; + bool ecc_enabled; +} SE_DeviceDescription; + +typedef struct XLA_TransferManager XLA_TransferManager; + +typedef struct XLA_ComputationPlacer XLA_ComputationPlacer; + +// Represents an XLA shape tree. +// Shapes are flattened in default traversal order. +typedef struct XLA_Shape { + char* bytes; + size_t size; +} XLA_Shape; + +// Represents a leaf node for a XLA shaped buffer. +typedef struct XLA_ShapedBuffer { + XLA_Shape on_host_shape; + XLA_Shape on_device_shape; + int device_ordinal; + + SE_DeviceMemoryBase* bases; + size_t count; +} XLA_ShapedBuffer; + +// Represents a leaf XLA literal. 
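+// buffers[i] and sizes[i] are expected to hold the i-th flattened leaf's data
+// and byte size; count is the number of leaves.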
+typedef struct XLA_Literal { + char** buffers; + size_t* sizes; + size_t count; + XLA_Shape shape; +} XLA_Literal; + +typedef void (*XLA_CallbackFn)(void*); +typedef void (*XLA_StatusCallbackFn)(void*, SE_Status*); + +extern "C" { + +SE_Platform* TpuPlatform_New(); +void TpuPlatform_Free(SE_Platform* platform); +void TpuPlatform_Initialize(SE_Platform* platform, size_t options_size, + const char** options_key, + const char** options_value, SE_Status* status); +bool TpuPlatform_Initialized(SE_Platform* platform); +SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, + SE_StreamExecutorConfig* config, + SE_Status* status); +SE_PlatformId TpuPlatform_Id(SE_Platform* platform); +int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform); +int64_t TpuPlatform_TpuMemoryLimit(SE_Platform* platform); + +void TpuExecutor_Init(SE_StreamExecutor* executor, int device_ordinal, + SE_DeviceOptions* device_options, SE_Status* status); +void TpuExecutor_Free(SE_StreamExecutor* executor); + +int TpuExecutor_PlatformDeviceCount(SE_StreamExecutor* executor); + +SE_DeviceMemoryBase TpuExecutor_Allocate(SE_StreamExecutor* executor, + uint64_t size, int64_t memory_space); +void TpuExecutor_Deallocate(SE_StreamExecutor* executor, + SE_DeviceMemoryBase* memory); +bool TpuExecutor_GetAllocatorStats(SE_StreamExecutor* executor, + SE_AllocatorStats* stats); +bool TpuExecutor_DeviceMemoryUsage(SE_StreamExecutor* executor, int64_t* free, + int64_t* total); + +bool TpuExecutor_AllocateStream(SE_StreamExecutor* executor, SE_Stream* stream); +void TpuExecutor_DeallocateStream(SE_StreamExecutor* executor, + SE_Stream* stream); +bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor, + SE_Stream* dependent, SE_Stream* other); +void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream, + SE_Status* status); + +void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event, + SE_Status* status); +void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event, + SE_Status* status); +int TpuExecutor_PollForEventStatus(SE_StreamExecutor* executor, + SE_Event* event); +void TpuExecutor_RecordEvent(SE_StreamExecutor* executor, SE_Stream* stream, + SE_Event* event, SE_Status* status); +void TpuExecutor_WaitForEvent(SE_StreamExecutor* executor, SE_Stream* stream, + SE_Event* event, SE_Status* status); + +bool TpuExecutor_AllocateTimer(SE_StreamExecutor* executor, SE_Timer* timer); +void TpuExecutor_DeallocateTimer(SE_StreamExecutor* executor, SE_Timer* timer); +bool TpuExecutor_StartTimer(SE_StreamExecutor* executor, SE_Stream* stream, + SE_Timer* timer); +bool TpuExecutor_StopTimer(SE_StreamExecutor* executor, SE_Stream* stream, + SE_Timer* timer); + +void TpuExecutor_SynchronousMemcpyToHost(SE_StreamExecutor* executor, + void* host_dst, + const SE_DeviceMemoryBase* device_src, + uint64_t size, SE_Status* status); +void TpuExecutor_SynchronousMemcpyFromHost(SE_StreamExecutor* executor, + SE_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size, + SE_Status* status); +bool TpuExecutor_MemcpyToHost(SE_StreamExecutor* executor, SE_Stream* stream, + void* host_dst, + const SE_DeviceMemoryBase* device_src, + uint64_t size); + +bool TpuExecutor_MemcpyFromHost(SE_StreamExecutor* executor, SE_Stream* stream, + SE_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size); + +void TpuExecutor_EnqueueInfeed(SE_StreamExecutor* executor, + int32_t infeed_queue_index, const uint8_t* data, + int64_t size, SE_Status* status); +void 
TpuExecutor_DequeueOutfeed(SE_StreamExecutor* executor, + int32_t outfeed_queue_index, uint8_t* data, + int64_t size, SE_Status* status); +void TpuExecutor_WaitForInfeedReady(SE_StreamExecutor* executor, + int32_t infeed_queue_index, + SE_Status* status); +void TpuExecutor_WaitForOutfeedReady(SE_StreamExecutor* executor, + int32_t outfeed_queue_index, + SE_Status* status); + +void TpuExecutor_BlockHostUntilDone(SE_StreamExecutor* executor, + SE_Stream* stream, SE_Status* status); +void TpuExecutor_BlockUntilDoneOrFailed(SE_StreamExecutor* executor, + SE_Status* status); +void TpuExecutor_SyncAndForgetFailedStreams(SE_StreamExecutor* executor); +bool TpuExecutor_SynchronizeAllActivity(SE_StreamExecutor* executor); + +SE_Stream* TpuStream_New(SE_StreamExecutor* parent); +void TpuStream_Free(SE_Stream*); +void* TpuStream_Stream(SE_Stream*); +bool TpuStream_Status(SE_Stream*); + +SE_Event* TpuEvent_New(SE_StreamExecutor* parent); +void TpuEvent_Free(SE_Event*); + +SE_Timer* TpuTimer_New(SE_StreamExecutor* parent); +void TpuTimer_Free(SE_Timer*); +int64_t TpuTimer_Nanoseconds(SE_Timer*); +int64_t TpuTimer_Microseconds(SE_Timer*); + +SE_Status* TpuStatus_New(); +SE_Status* TpuStatus_Create(int32_t code, const char* msg); +void TpuStatus_Free(SE_Status* status); +const char* TpuStatus_Message(SE_Status* status); +int TpuStatus_Code(SE_Status* status); +bool TpuStatus_Ok(SE_Status* status); + +SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default(); +void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal); +void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*); + +SE_DeviceDescription* TpuDeviceDescription_New(); +void TpuDeviceDescription_Free(SE_DeviceDescription* description); +void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor, + SE_DeviceDescription* description, + SE_Status* status); + +SE_DeviceOptions* TpuExecutor_NewDeviceOptions(unsigned flags); +void TpuExecutor_FreeDeviceOptions(SE_DeviceOptions* options); + +bool TpuExecutor_HostCallback(SE_StreamExecutor* executor, SE_Stream* stream, + SE_StatusCallbackFn callback_fn, void* ctx); + +XLA_TransferManager* TpuTransferManager_New(); +void TpuTransferManager_Free(XLA_TransferManager* manager); +SE_PlatformId TpuTransferManager_PlatformId(XLA_TransferManager* manager); +void TpuTransferManager_HostShapeToDeviceShape(XLA_TransferManager* manager, + XLA_Shape* host_shape, + XLA_Shape* device_shape); +void TpuTransferManager_TransferLiteralToDeviceAsync( + XLA_TransferManager* manager, SE_Stream* stream, XLA_Literal* literal, + XLA_ShapedBuffer* device_buffer, SE_Status* status); +void TpuTransferManager_TransferLiteralFromDevice( + XLA_TransferManager* manager, SE_Stream* stream, + XLA_ShapedBuffer* device_buffer, XLA_Literal* literal, + XLA_StatusCallbackFn callback, void* ctx); + +int64_t TpuTransferManager_GetByteSizeRequirement(XLA_TransferManager* manager, + XLA_Shape* shape); +void TpuTransferManager_WriteSingleTupleIndexTable( + XLA_TransferManager* manager, SE_Stream* stream, + SE_DeviceMemoryBase* elements, size_t elements_len, XLA_Shape* shape, + SE_DeviceMemoryBase* region, SE_Status* status); + +XLA_ComputationPlacer* TpuComputationPlacer_New(); +void TpuComputationPlacer_Free(XLA_ComputationPlacer* placer); +} + +// extern "C" + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_C_API_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_executor_interface.h b/tensorflow/stream_executor/tpu/tpu_executor_interface.h new file mode 100644 index 00000000000..5b00f615ca7 --- 
/dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_executor_interface.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
+
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+#include "tensorflow/stream_executor/timer.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+namespace tpu {
+class TpuCore;
+}  // namespace tpu
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuExecutorInterface
+    : public ::stream_executor::internal::StreamExecutorInterface {
+ public:
+  using Status = ::stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = ::stream_executor::port::StatusOr<T>;
+
+  class TemporaryDeviceMemory {
+   public:
+    virtual ~TemporaryDeviceMemory() {}
+    virtual stream_executor::DeviceMemoryBase AsDeviceMemoryBase() const = 0;
+  };
+
+  virtual StatusOr<std::unique_ptr<TemporaryDeviceMemory>>
+  CreateTemporaryDeviceMemory(int64 memory_space, int64 byte_offset,
+                              int64 size) {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  virtual const TpuPlatformInterface& platform() const {
+    LOG(FATAL) << "Unimplemented.";
+  }
+
+  virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
+};
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_INTERFACE_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.cc b/tensorflow/stream_executor/tpu/tpu_node_context.cc
new file mode 100644
index 00000000000..2a4954d5b08
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.cc
@@ -0,0 +1,100 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/device_memory_allocator.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using stream_executor::port::Status;
+using stream_executor::port::StatusOr;
+
+/*static*/ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Initialize(
+    int device_ordinal) {
+  StatusHelper status;
+  XLA_TpuNodeContext* node_context =
+      TpuNodeContext_Create(device_ordinal, status.c_status);
+  if (!status.status().ok()) {
+    TpuNodeContext_Free(node_context);
+    return status.status();
+  }
+  return std::make_unique<TpuNodeContext>(device_ordinal, node_context);
+}
+
+TpuNodeContext::~TpuNodeContext() { TpuNodeContext_Free(node_context_); }
+
+/* static */
+Status TpuNodeContext::StopChipHeartbeats() {
+  StatusHelper status;
+  TpuNodeContext_StopChipHeartbeats(status.c_status);
+  return status.status();
+}
+
+/* static */
+Status TpuNodeContext::CloseTpuHost() {
+  StatusHelper status;
+  TpuNodeContext_CloseTpuHost(status.c_status);
+  return status.status();
+}
+
+/* static */
+tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
+  return TpuPlatformInterface::GetRegisteredPlatform();
+}
+
+/* static */
+stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
+  static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
+      new stream_executor::StreamExecutorMemoryAllocator(
+          platform(),
+          xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
+  return memory_allocator;
+}
+
+/* static */
+xla::Backend* TpuNodeContext::backend() {
+  static xla::Backend* backend =
+      xla::Backend::CreateBackend(
+          xla::BackendOptions().set_platform(platform()))
+          .ValueOrDie()
+          .release();
+  return backend;
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+    int device_ordinal) {
+  return backend()->BorrowStream(device_ordinal);
+}
+
+/* static */
+StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
+    stream_executor::StreamExecutor* executor) {
+  return backend()->BorrowStream(executor);
+}
+
+/* static */
+xla::TransferManager* TpuNodeContext::transfer_manager() {
+  return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_node_context.h b/tensorflow/stream_executor/tpu/tpu_node_context.h
new file mode 100644
index 00000000000..e1e1ffc67f7
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_node_context.h
@@ -0,0 +1,89 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/tpu/status_helper.h" +#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" + +namespace tensorflow { +namespace tpu { + +class TpuNodeContext final { + public: + using Status = stream_executor::port::Status; + template + using StatusOr = stream_executor::port::StatusOr; + + static StatusOr> Initialize( + int device_ordinal); + + explicit TpuNodeContext(int device_ordinal, XLA_TpuNodeContext* node_context) + : device_ordinal_(device_ordinal), node_context_(node_context) { + CHECK_NE(node_context, nullptr); + } + ~TpuNodeContext(); + + TpuNodeContext(const TpuNodeContext&) = delete; + TpuNodeContext& operator=(const TpuNodeContext&) = delete; + + static Status StopChipHeartbeats(); + + static Status CloseTpuHost(); + + static tensorflow::tpu::TpuPlatformInterface* platform(); + + static stream_executor::DeviceMemoryAllocator* memory_allocator(); + + static xla::TransferManager* transfer_manager(); + + static xla::Backend* backend(); + + static StatusOr BorrowStream(int device_ordinal); + + static StatusOr BorrowStream( + stream_executor::StreamExecutor* executor); + + stream_executor::StreamExecutor* stream_executor() { + LOG(FATAL) << "Not implemented yet."; + } + + std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; } + + int index_on_host() { LOG(FATAL) << "Not implemented yet."; } + + int device_ordinal() const { return device_ordinal_; } + + private: + const int device_ordinal_; + XLA_TpuNodeContext* const node_context_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h new file mode 100644 index 00000000000..d2684e47df1 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_node_context_c_api.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ + +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; + +XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, + SE_Status* status); +void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); + +void TpuNodeContext_StopChipHeartbeats(SE_Status* status); +void TpuNodeContext_CloseTpuHost(SE_Status* status); + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc new file mode 100644 index 00000000000..c44926df749 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_platform.cc @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/tpu/tpu_platform.h" + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/tpu/status_helper.h" +#include "tensorflow/stream_executor/tpu/tpu_executor.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +namespace tensorflow { + +PLATFORM_DEFINE_ID(TpuPlatform::kId); +TpuPlatform* tpu_registered_platform = nullptr; + +using Status = ::stream_executor::port::Status; +template +using StatusOr = ::stream_executor::port::StatusOr; + +TpuPlatform::TpuPlatform() { platform_ = TpuPlatform_New(); } + +TpuPlatform* TpuPlatform::GetRegisteredPlatform() { + return tpu_registered_platform; +} + +Status TpuPlatform::Initialize( + const std::map& platform_options) { + StatusHelper status; + + size_t options_size = platform_options.size(); + const char** options_key = + static_cast(malloc(sizeof(const char*) * options_size)); + const char** options_value = + static_cast(malloc(sizeof(const char*) * options_size)); + + size_t i = 0; + for (const auto& option : platform_options) { + options_key[i] = option.first.c_str(); + options_value[i] = option.second.c_str(); + i++; + } + + TpuPlatform_Initialize(platform_, options_size, options_key, options_value, + status.c_status); + + free(options_key); + free(options_value); + + return status.status(); +} + +TpuPlatform::~TpuPlatform() { TpuPlatform_Free(platform_); } + +int TpuPlatform::VisibleDeviceCount() const { + return TpuPlatform_VisibleDeviceCount(platform_); +} + +StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor( + const ::stream_executor::StreamExecutorConfig& config) { + return executor_cache_.GetOrCreate( + config, [&]() { return GetUncachedExecutor(config); }); +} + +StatusOr> +TpuPlatform::GetUncachedExecutor( + const ::stream_executor::StreamExecutorConfig& config) { + SE_StreamExecutorConfig* c_config = 
TpuStreamExecutorConfig_Default(); + + TpuStreamExecutorConfig_SetOrdinal(c_config, config.ordinal); + + StatusHelper status; + SE_StreamExecutor* executor = + TpuPlatform_GetExecutor(platform_, c_config, status.c_status); + TpuStreamExecutorConfig_Free(c_config); + if (!status.ok()) { + return status.status(); + } + return std::make_unique( + this, absl::make_unique(this, executor), + config.ordinal); +} + +::stream_executor::Platform::Id TpuPlatform::id() const { + return TpuPlatform::kId; +} + +const std::string& TpuPlatform::Name() const { + static std::string* name = new std::string(kName); + return *name; +} + +int64 TpuPlatform::TpuMemoryLimit() { + return TpuPlatform_TpuMemoryLimit(platform_); +} + +} // namespace tensorflow + +void RegisterTpuPlatform() { + tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform(); + std::unique_ptr platform( + tensorflow::tpu_registered_platform); + SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + std::move(platform))); +} + +REGISTER_MODULE_INITIALIZER(tpu_platform, RegisterTpuPlatform()); + +// Note that module initialization sequencing is not supported in the +// open-source project, so this will be a no-op there. +REGISTER_MODULE_INITIALIZER_SEQUENCE(tpu_platform, multi_platform_manager); +REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, + tpu_platform); diff --git a/tensorflow/stream_executor/tpu/tpu_platform.h b/tensorflow/stream_executor/tpu/tpu_platform.h new file mode 100644 index 00000000000..9a67045dec6 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_platform.h @@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/executor_cache.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" + +namespace tensorflow { + +class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { + public: + using StreamMap = + absl::flat_hash_map; + + static const ::stream_executor::Platform::Id kId; + static constexpr char kName[] = "TPU"; + + using Status = ::stream_executor::port::Status; + template + using StatusOr = ::stream_executor::port::StatusOr; + + TpuPlatform(); + + ~TpuPlatform() override; + + static TpuPlatform* GetRegisteredPlatform(); + + Id id() const override; + + const std::string& Name() const override; + + int VisibleDeviceCount() const override; + + int64 TpuMemoryLimit() override; + + bool Initialized() const override { + return TpuPlatform_Initialized(platform_); + } + + Status Initialize( + const std::map& platform_options) override; + + Status Reset() override { return Reset(false); } + + Status Reset(bool only_tear_down) override { + LOG(FATAL) << "Not yet implemented"; + } + + StatusOr> + DescriptionForDevice(int ordinal) const override { + LOG(FATAL) << "Not yet implemented"; + } + + StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice( + int ordinal) override { + stream_executor::StreamExecutorConfig config; + config.ordinal = ordinal; + return GetExecutor(config); + } + + StatusOr<::stream_executor::StreamExecutor*> + ExecutorForDeviceWithPluginConfig( + int ordinal, + const ::stream_executor::PluginConfig& plugin_config) override { + stream_executor::StreamExecutorConfig config; + config.ordinal = ordinal; + config.plugin_config = plugin_config; + return GetExecutor(config); + } + + StatusOr<::stream_executor::StreamExecutor*> GetExecutor( + const ::stream_executor::StreamExecutorConfig& config) override; + + StatusOr> + GetUncachedExecutor( + const ::stream_executor::StreamExecutorConfig& config) override; + + void RegisterTraceListener( + std::unique_ptr listener) override { + LOG(FATAL) << "Not yet implemented"; + } + + void UnregisterTraceListener( + stream_executor::TraceListener* listener) override { + LOG(FATAL) << "Not yet implemented"; + } + + StreamMap* stream_map() { return &stream_map_; } + + private: + SE_Platform* platform_; + + stream_executor::ExecutorCache executor_cache_; + StreamMap stream_map_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc new file mode 100644 index 00000000000..c5b8ece32af --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
+
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+
+namespace tensorflow {
+namespace tpu {
+
+/* static */
+TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
+  // Prefer TpuPlatform if it's registered.
+  auto status_or_tpu_platform =
+      stream_executor::MultiPlatformManager::PlatformWithName("TPU");
+  if (status_or_tpu_platform.ok()) {
+    return static_cast<TpuPlatformInterface*>(
+        status_or_tpu_platform.ValueOrDie());
+  }
+  if (status_or_tpu_platform.status().code() != error::NOT_FOUND) {
+    LOG(WARNING) << "Error when getting the TPU platform: "
+                 << status_or_tpu_platform.status();
+    return nullptr;
+  }
+
+  // Use any other registered TPU platform.
+  auto status_or_other_tpu_platforms =
+      stream_executor::MultiPlatformManager::PlatformsWithFilter(
+          [](const stream_executor::Platform* platform) {
+            return dynamic_cast<const TpuPlatformInterface*>(platform) !=
+                   nullptr;
+          });
+  if (!status_or_other_tpu_platforms.ok()) {
+    LOG(WARNING) << "Error when getting other TPU platforms: "
+                 << status_or_other_tpu_platforms.status();
+    return nullptr;
+  }
+  auto other_tpu_platforms = status_or_other_tpu_platforms.ValueOrDie();
+  if (!other_tpu_platforms.empty()) {
+    LOG(WARNING) << other_tpu_platforms.size()
+                 << " TPU platforms registered, selecting "
+                 << other_tpu_platforms[0]->Name();
+    return static_cast<TpuPlatformInterface*>(other_tpu_platforms[0]);
+  }
+
+  LOG(WARNING) << "No TPU platform registered";
+  return nullptr;
+}
+
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
new file mode 100644
index 00000000000..5c7aa8efe94
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/platform.h"
+
+namespace tensorflow {
+namespace tpu {
+
+class TpuPlatformInterface : public stream_executor::Platform {
+ public:
+  using Status = stream_executor::port::Status;
+
+  // Returns a TPU platform to be used by TPU ops. If multiple TPU platforms are
+  // registered, finds the most suitable one.
Returns nullptr if no TPU platform + // is registered or an error occurred. + static TpuPlatformInterface* GetRegisteredPlatform(); + + virtual Status Reset() { return Reset(false); } + + virtual Status Reset(bool only_tear_down) = 0; + + virtual int64 TpuMemoryLimit() = 0; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_INTERFACE_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_stream.h b/tensorflow/stream_executor/tpu/tpu_stream.h new file mode 100644 index 00000000000..b8fd10df5d9 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_stream.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ + +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +class TpuStream : public stream_executor::internal::StreamInterface { + public: + explicit TpuStream(SE_Stream* stream) : stream_(stream) {} + ~TpuStream() override { TpuStream_Free(stream_); } + + private: + SE_Stream* stream_; +}; + +class TpuEvent : public ::stream_executor::internal::EventInterface { + public: + explicit TpuEvent(SE_Event* event) : event_(event) {} + ~TpuEvent() override { TpuEvent_Free(event_); } + + private: + SE_Event* event_; +}; + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_stream_interface.h b/tensorflow/stream_executor/tpu/tpu_stream_interface.h new file mode 100644 index 00000000000..2e5e02ded7d --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_stream_interface.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_ + +#include "tensorflow/stream_executor/stream_executor_internal.h" + +namespace tensorflow { +namespace tpu { + +class TpuStreamInterface : public ::stream_executor::internal::StreamInterface { +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_STREAM_INTERFACE_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_timer.h b/tensorflow/stream_executor/tpu/tpu_timer.h new file mode 100644 index 00000000000..246a0b7eb32 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_timer.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_ + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" + +namespace tensorflow { + +class TpuTimer : public ::stream_executor::internal::TimerInterface { + public: + explicit TpuTimer(SE_Timer* timer) : timer_(timer) {} + ~TpuTimer() override { TpuTimer_Free(timer_); } + uint64 Microseconds() const override { return TpuTimer_Microseconds(timer_); } + uint64 Nanoseconds() const override { return TpuTimer_Nanoseconds(timer_); } + + private: + SE_Timer* timer_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TIMER_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc new file mode 100644 index 00000000000..473585a12d1 --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
+#include "tensorflow/stream_executor/tpu/proto_helper.h"
+#include "tensorflow/stream_executor/tpu/status_helper.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+
+namespace tensorflow {
+
+using Status = stream_executor::port::Status;
+template <typename T>
+using StatusOr = stream_executor::port::StatusOr<T>;
+
+TpuTransferManager::TpuTransferManager() {
+  manager_ = TpuTransferManager_New();
+}
+
+TpuTransferManager::~TpuTransferManager() { TpuTransferManager_Free(manager_); }
+
+stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
+  return TpuPlatform::kId;
+}
+
+xla::Shape TpuTransferManager::HostShapeToDeviceShape(
+    const xla::Shape& host_shape) const {
+  XLA_Shape c_host_shape;
+  XLA_Shape c_device_shape;
+
+  TpuConversions::XlaShapeToCShape(host_shape, &c_host_shape);
+
+  TpuTransferManager_HostShapeToDeviceShape(manager_, &c_host_shape,
+                                            &c_device_shape);
+  xla::Shape device_shape = TpuConversions::CShapeToXlaShape(&c_device_shape);
+  TpuConversions::CShapeCleanup(&c_host_shape);
+  TpuConversions::CShapeCleanup(&c_device_shape);
+  return device_shape;
+}
+
+Status TpuTransferManager::TransferLiteralToDeviceAsync(
+    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+    const xla::ShapedBuffer& device_buffer,
+    const TransferMetadata* transfer_metadata) {
+  StatusHelper status;
+
+  XLA_Literal c_literal;
+  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+  XLA_ShapedBuffer c_device_buffer;
+  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+                                                 &c_device_buffer);
+
+  TpuTransferManager_TransferLiteralToDeviceAsync(
+      manager_,
+      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
+          stream->implementation()),
+      &c_literal, &c_device_buffer, status.c_status);
+  TpuConversions::CShapedBufferCleanup(&c_device_buffer);
+  TpuConversions::CLiteralCleanup(&c_literal);
+  return status.status();
+}
+
+struct TransferFromDeviceState {
+  std::atomic<int64> remaining_transfers;
+  StatusHelper status_helper;
+  std::function<void(Status)> done;
+
+  void TransferFinished(SE_Status* status) {
+    if (!TpuStatus_Ok(status) && TpuStatus_Ok(status_helper.c_status)) {
+      status_helper.c_status = status;
+    } else {
+      TpuStatus_Free(status);
+    }
+
+    if (--remaining_transfers == 0) {
+      done(status_helper.status());
+      delete this;
+    }
+  }
+};
+
+void TransferLiteralFromDeviceTrampoline(void* ctx, SE_Status* status) {
+  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
+}
+
+void TpuTransferManager::TransferLiteralFromDevice(
+    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+    const TransferMetadata* transfer_metadata) {
+  TransferFromDeviceState* state = new TransferFromDeviceState;
+  state->remaining_transfers = 1;
+  state->done = done;
+  XLA_ShapedBuffer c_device_buffer;
+  TpuConversions::XLAShapedBufferToCShapedBuffer(device_buffer,
+                                                 &c_device_buffer);
+  XLA_Literal c_literal;
+  TpuConversions::XLALiteralToCLiteral(literal, &c_literal);
+
+  TpuTransferManager_TransferLiteralFromDevice(
+      manager_,
+      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
stream->implementation()), + &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state); + TpuConversions::CShapedBufferCleanup(&c_device_buffer); + TpuConversions::CLiteralCleanup(&c_literal); +} + +int64 TpuTransferManager::GetByteSizeRequirement( + const xla::Shape& shape) const { + XLA_Shape c_shape; + TpuConversions::XlaShapeToCShape(shape, &c_shape); + + int64 size_in_bytes = + TpuTransferManager_GetByteSizeRequirement(manager_, &c_shape); + + TpuConversions::CShapeCleanup(&c_shape); + return size_in_bytes; +} + +Status TpuTransferManager::WriteSingleTupleIndexTable( + stream_executor::Stream* stream, + absl::Span elements, + const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) { + CHECK_GT(elements.size(), 0); + SE_DeviceMemoryBase* elements_bases = + new SE_DeviceMemoryBase[elements.size()]; + for (int i = 0; i < elements.size(); i++) { + elements_bases[i] = + SE_DeviceMemoryBase{const_cast(elements[i].opaque()), + elements[i].size(), elements[i].payload()}; + } + XLA_Shape c_shape; + TpuConversions::XlaShapeToCShape(shape, &c_shape); + SE_DeviceMemoryBase region_base{region->opaque(), region->size(), + region->payload()}; + StatusHelper status; + + TpuTransferManager_WriteSingleTupleIndexTable( + manager_, + TpuPlatform::GetRegisteredPlatform()->stream_map()->at( + stream->implementation()), + elements_bases, elements.size(), &c_shape, ®ion_base, status.c_status); + + delete[] elements_bases; + TpuConversions::CShapeCleanup(&c_shape); + return status.status(); +} + +} // namespace tensorflow diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.h b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h new file mode 100644 index 00000000000..163ac81ea5f --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
+
+namespace tensorflow {
+
+class TpuTransferManager : public xla::TransferManager {
+ public:
+  TpuTransferManager();
+  ~TpuTransferManager() override;
+
+  using Status = stream_executor::port::Status;
+  template <typename T>
+  using StatusOr = stream_executor::port::StatusOr<T>;
+
+  stream_executor::Platform::Id PlatformId() const override;
+
+  xla::Shape HostShapeToDeviceShape(
+      const xla::Shape& host_shape) const override;
+
+  Status TransferLiteralToDeviceAsync(
+      stream_executor::Stream* stream, const xla::LiteralSlice& literal,
+      const xla::ShapedBuffer& device_buffer,
+      const TransferMetadata* transfer_metadata) override;
+
+  void TransferLiteralFromDevice(
+      stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
+      xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
+      const TransferMetadata* transfer_metadata) override;
+
+  Status TransferLiteralToInfeed(stream_executor::StreamExecutor* executor,
+                                 const xla::LiteralSlice& literal) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  Status TransferLiteralFromOutfeed(
+      stream_executor::StreamExecutor* executor,
+      const xla::Shape& literal_shape,
+      xla::MutableBorrowingLiteral literal) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  Status ResetDevices(
+      absl::Span<stream_executor::StreamExecutor* const> executor) override {
+    LOG(FATAL) << "Not yet implemented";
+  }
+
+  int64 GetByteSizeRequirement(const xla::Shape& shape) const override;
+
+  Status WriteSingleTupleIndexTable(
+      stream_executor::Stream* stream,
+      absl::Span<const stream_executor::DeviceMemoryBase> elements,
+      const xla::Shape& shape,
+      stream_executor::DeviceMemoryBase* region) override;
+
+ private:
+  XLA_TransferManager* manager_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
new file mode 100644
index 00000000000..f7f0c6fbe2c
--- /dev/null
+++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/stream_executor/tpu/tpu_platform.h"
+#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
+  return std::make_unique<TpuTransferManager>();
+}
+
+static bool InitModule() {
+  xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
+                                                CreateTpuTransferManager);
+  return true;
+}
+static bool module_initialized = InitModule();
+
+}  // namespace tensorflow