diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 69238456d57..bffb44c1b97 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -13,6 +13,7 @@ cc_library( srcs = ["tpu_rewrite_pass_registration.cc"], deps = [ ":distributed_tpu_configuration_rewrite_pass", + ":distributed_tpu_rewrite_pass", ":encapsulate_tpu_computations_pass", ":variable_merger_pass", "//tensorflow/core:core_cpu", @@ -101,3 +102,120 @@ cc_library( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "distributed_tpu_rewrite_pass_internal", + srcs = ["distributed_tpu_rewrite_pass_internal.cc"], + hdrs = ["distributed_tpu_rewrite_pass_internal.h"], + deps = [ + "//tensorflow/core:framework", + "@com_google_absl//absl/random", + ], +) + +cc_library( + name = "distributed_tpu_rewrite_pass", + srcs = [ + "distributed_tpu_rewrite_pass.cc", + ], + hdrs = [ + "distributed_tpu_rewrite_pass.h", + ], + deps = [ + ":cond_builder", + ":distributed_tpu_rewrite_helpers", + ":distributed_tpu_rewrite_pass_internal", + ":host_training_loop_optimization_util", + ":incomplete_nodedef_builder", + "//tensorflow/compiler/jit:encapsulate_util", + "//tensorflow/compiler/jit:shape_inference", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/tf2xla:sharding_util", + "//tensorflow/compiler/tf2xla:side_effect_util", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/client:sharding_builder", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core/common_runtime:function", + "//tensorflow/core/common_runtime:graph_constructor", + "//tensorflow/core/common_runtime:lower_function_call_op", + "//tensorflow/core/common_runtime:lower_functional_ops", + "//tensorflow/core/common_runtime:lower_if_op", + "//tensorflow/core/common_runtime:lower_while_op", + "//tensorflow/core/common_runtime:optimization_registry", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", + "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "//tensorflow/core/tpu:tpu_compile_interface", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/stream_executor/tpu:tpu_topology_external", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "incomplete_nodedef_builder", + srcs = ["incomplete_nodedef_builder.cc"], + hdrs = ["incomplete_nodedef_builder.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cond_builder", + srcs = ["cond_builder.cc"], + hdrs = ["cond_builder.h"], + deps = [ + ":incomplete_nodedef_builder", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + 
"//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "host_training_loop_optimization_util", + srcs = [ + "host_training_loop_optimization_util.cc", + ], + hdrs = [ + "host_training_loop_optimization_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":distributed_tpu_rewrite_pass_internal", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_util", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/tensorflow/core/tpu/graph_rewrite/cond_builder.cc b/tensorflow/core/tpu/graph_rewrite/cond_builder.cc new file mode 100644 index 00000000000..e16ae08aec3 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/cond_builder.cc @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h" + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h" + +namespace tensorflow { + +CondBuilder::CondBuilder(string name, string device, const NodeDebugInfo& debug, + Graph* graph) + : graph_(graph), name_(std::move(name)), device_(std::move(device)) { + auto new_name = [graph, this](string suffix) { + return graph->NewName(strings::StrCat(name_, "/", suffix)); + }; + TF_CHECK_OK( + IncompleteNodeDefBuilder::Identity(new_name("pred"), DT_BOOL, debug) + .Device(device_) + .Build(graph_, &pred_)); + Node* switch_pred; + TF_CHECK_OK( + IncompleteNodeDefBuilder::Switch(new_name("switch_pred"), DT_BOOL, debug) + .Device(device_) + .Build(graph_, &switch_pred)); + graph_->AddEdge(pred(), 0, switch_pred, 0); + graph_->AddEdge(pred(), 0, switch_pred, 1); + TF_CHECK_OK( + IncompleteNodeDefBuilder::Identity(new_name("switch_f"), DT_BOOL, debug) + .Device(device_) + .Build(graph_, &switch_f_)); + TF_CHECK_OK( + IncompleteNodeDefBuilder::Identity(new_name("switch_t"), DT_BOOL, debug) + .Device(device_) + .Build(graph_, &switch_t_)); + graph_->AddEdge(switch_pred, kElseBranch, switch_f_, 0); + graph_->AddEdge(switch_pred, kThenBranch, switch_t_, 0); + Node* merge_pred; + TF_CHECK_OK(IncompleteNodeDefBuilder::Merge(new_name("merge_pred"), DT_BOOL, + debug, /*n=*/2) + .Device(device_) + .Build(graph_, &merge_pred)); + graph_->AddEdge(switch_f_, 0, merge_pred, kElseBranch); + graph_->AddEdge(switch_t_, 0, merge_pred, kThenBranch); + // Note: 
when additional return values are added then there should be a + // control dependency between those merge nodes and control_successor_ to + // ensure that it is control successor of conditional. + control_successor_ = merge_pred; +} + +Node* CondBuilder::pred() { return pred_; } + +Node* CondBuilder::switch_f() { return switch_f_; } + +Node* CondBuilder::switch_t() { return switch_t_; } + +Node* CondBuilder::control_successor() { return control_successor_; } + +Status CondBuilder::AddInput(const string& input_name, const DataType& type, + const string& device, const NodeDebugInfo& debug, + Node** input) { + auto b = IncompleteNodeDefBuilder::Switch( + graph_->NewName(strings::StrCat(name_, "/", input_name)), type, debug); + TF_RETURN_IF_ERROR(b.Device(device).Build(graph_, input)); + graph_->AddEdge(pred(), 0, *input, 1); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/cond_builder.h b/tensorflow/core/tpu/graph_rewrite/cond_builder.h new file mode 100644 index 00000000000..29e264dfc0a --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/cond_builder.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ + +#include <string> + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Conditional builder. +// Convenience builder to make it easy to construct a conditional. E.g., +// Node* pred = ...; +// CondBuilder cb("cond", g); +// auto switch_var = cb.AddInput("var", DT_RESOURCE); +// g->AddEdge(pred, 0, cb.pred(), 0); +// Will create the nodes of a conditional that takes as input a resource +// variable ("var") as input and that switches on pred. +// +// This currently only handles the case needed by distributed_tpu_rewrite_pass +// and is not completely general. +class CondBuilder { + public: + enum Branch { kElseBranch = 0, kThenBranch = 1 }; + + CondBuilder(string name, string device, const NodeDebugInfo& debug, + Graph* graph); + + // Returns node corresponding to the predicate input. + Node* pred(); + + // Returns node corresponding to switch_f branch of predicate switch. + Node* switch_f(); + + // Returns node corresponding to switch_t branch of predicate switch. + Node* switch_t(); + + // Returns node corresponding to control successor. + Node* control_successor(); + + // Returns the Switch node to feed a value of the given type into the + // conditional. 
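+  // The returned node is a Switch whose predicate input is wired to pred();
+  // as with the internal predicate switch, output kElseBranch (0) carries the
+  // false-branch value and output kThenBranch (1) the true-branch value. The
+  // caller connects the value being switched to input 0 of the returned node.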
+ Status AddInput(const string& input_name, const DataType& type, + const string& device, const NodeDebugInfo& debug, + Node** input); + + private: + Node* control_successor_; + Node* switch_f_; + Node* switch_t_; + Node* pred_; + Graph* const graph_; + const string name_; + const string device_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc new file mode 100644 index 00000000000..f0032f5dfd9 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -0,0 +1,4105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Compilation for distributed TPU (TPU_REPLICATED_CORE devices). + +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h" + +#include <queue> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/lower_function_call_op.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" +#include "tensorflow/core/common_runtime/lower_if_op.h" +#include "tensorflow/core/common_runtime/lower_while_op.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include 
"tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h" +#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h" +#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h" +#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h" +#include "tensorflow/core/tpu/tpu_compile_interface.h" +#include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" + +namespace tensorflow { + +namespace { + +// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 +// topology. +constexpr int kTPUTopologyRank = 4; + +// An upper bound on how many cores may be present in the topology. +static constexpr int kTPUMaxTopologySize = 4096; + +// Attribute containing the serialized xla::OpSharding to be passed to the +// corresponding XLA HLO operation, which represents how a shape is distributed +// across logical cores, e.g., replication, single-device, or partitioning. +const char kShardingAttribute[] = "_XlaSharding"; + +const char kTPUPartitionedInput[] = "TPUPartitionedInput"; +const char kTPUPartitionedOutput[] = "TPUPartitionedOutput"; + +static const char* const kTPUCompilationResultAttr = "_tpu_compilation_status"; +static const char* const kPostDeviceRewriteAttr = "_post_device_rewrite"; + +class IntrusiveHeapLink { + public: + using size_type = size_t; + static constexpr size_type kNotMember = -1; + + IntrusiveHeapLink() = default; + + // Only IntrusiveHeap and LinkAccess objects should make these objects. + explicit IntrusiveHeapLink(size_type pos) : pos_{pos} {} + + // Only IntrusiveHeap and LinkAccess should get the value. + size_type get() const { return pos_; } + + private: + size_type pos_{kNotMember}; +}; + +template <typename T, IntrusiveHeapLink T::*M> +struct IntrusiveHeapDataMemberLinkAccess { + IntrusiveHeapLink Get(const T* elem) const { return elem->*M; } + void Set(T* elem, IntrusiveHeapLink link) const { elem->*M = link; } +}; + +template <typename T> +struct DefaultIntrusiveHeapLinkAccess { + IntrusiveHeapLink Get(const T* elem) const { return elem->heap; } + void Set(T* elem, IntrusiveHeapLink link) const { elem->heap = link; } +}; + +template <typename T, typename PtrCompare, + typename LinkAccess = DefaultIntrusiveHeapLinkAccess<T>, + typename Alloc = std::allocator<T*>> +class IntrusiveHeap { + public: + typedef typename IntrusiveHeapLink::size_type size_type; + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef PtrCompare pointer_compare_type; + typedef LinkAccess link_access_type; + typedef Alloc allocator_type; + + explicit IntrusiveHeap( + const pointer_compare_type& comp = pointer_compare_type(), + const link_access_type& link_access = link_access_type(), + const allocator_type& alloc = allocator_type()) + : rep_(comp, link_access, alloc) {} + + size_type size() const { return heap().size(); } + + bool empty() const { return heap().empty(); } + + // Return the top element, but don't remove it. 
+ pointer top() const { + DCHECK(!empty()); + return heap()[0]; + } + + // Remove the top() pointer from the heap and return it. + pointer Pop() { + pointer t = top(); + Remove(t); + return t; + } + + // Insert 't' into the heap. + void Push(pointer t) { + SetPositionOf(t, heap().size()); + heap().push_back(t); + FixHeapUp(t); + } + + // Adjust the heap to accommodate changes in '*t'. + void Adjust(pointer t) { + DCHECK(Contains(t)); + size_type h = GetPositionOf(t); + if (h != 0 && compare()(t, heap()[(h - 1) >> 1])) { + FixHeapUp(t); + } else { + FixHeapDown(t); + } + } + + // Remove the specified pointer from the heap. + void Remove(pointer t) { + DCHECK(Contains(t)); + size_type h = GetPositionOf(t); + SetPositionOf(t, IntrusiveHeapLink::kNotMember); + if (h == heap().size() - 1) { + // Fast path for removing from back of heap. + heap().pop_back(); + return; + } + // Move the element from the back of the heap to overwrite 't'. + pointer& elem = heap()[h]; + elem = heap().back(); + SetPositionOf(elem, h); // Element has moved, so update its link. + heap().pop_back(); + Adjust(elem); // Restore the heap invariant. + } + + void Clear() { heap().clear(); } + + bool Contains(const_pointer t) const { + size_type h = GetPositionOf(t); + return (h != IntrusiveHeapLink::kNotMember) && (h < size()) && + heap()[h] == t; + } + + void reserve(size_type n) { heap().reserve(n); } + + size_type capacity() const { return heap().capacity(); } + + allocator_type get_allocator() const { return rep_.heap_.get_allocator(); } + + private: + typedef std::vector<pointer, allocator_type> heap_type; + + // Empty base class optimization for pointer_compare and link_access. + // The heap_ data member retains a copy of the allocator, so it is not + // stored explicitly. + struct Rep : pointer_compare_type, link_access_type { + explicit Rep(const pointer_compare_type& cmp, + const link_access_type& link_access, + const allocator_type& alloc) + : pointer_compare_type(cmp), + link_access_type(link_access), + heap_(alloc) {} + heap_type heap_; // NOLINT + }; + + const pointer_compare_type& compare() const { return rep_; } + + const link_access_type& link_access() const { return rep_; } + + const heap_type& heap() const { return rep_.heap_; } + heap_type& heap() { return rep_.heap_; } + + size_type GetPositionOf(const_pointer t) const { + return link_access().Get(t).get(); + } + + void SetPositionOf(pointer t, size_type pos) const { + return link_access().Set(t, IntrusiveHeapLink(pos)); + } + + void FixHeapUp(pointer t) { + size_type h = GetPositionOf(t); + while (h != 0) { + size_type parent = (h - 1) >> 1; + if (compare()(heap()[parent], t)) { + break; + } + heap()[h] = heap()[parent]; + SetPositionOf(heap()[h], h); + h = parent; + } + heap()[h] = t; + SetPositionOf(t, h); + } + + void FixHeapDown(pointer t) { + size_type h = GetPositionOf(t); + for (;;) { + size_type kid = (h << 1) + 1; + if (kid >= heap().size()) { + break; + } + if (kid + 1 < heap().size() && compare()(heap()[kid + 1], heap()[kid])) { + ++kid; + } + if (compare()(t, heap()[kid])) { + break; + } + heap()[h] = heap()[kid]; + SetPositionOf(heap()[h], h); + h = kid; + } + + heap()[h] = t; + SetPositionOf(t, h); + } + + Rep rep_; +}; + +string CoreDeviceLabel(int core) { + return strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core); +} + +// Creates a unique node name with a particular prefix. 
+string UniqueNodeName(const StringPiece prefix, Graph* graph) { + return graph->NewName(strings::StrCat(prefix, "/_", internal::GetNodeId())); +} + +Status SetNodeDeviceForTPUCommunication(DeviceNameUtils::ParsedName device, + const string& target_device_type, + Node* node) { + TF_RET_CHECK(device.has_type && device.type == DEVICE_TPU_NODE); + TF_RET_CHECK(device.has_id); + TF_RET_CHECK(HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)); + + // Store the device instance as an attr on the Node. + TF_RETURN_IF_ERROR(SetDeviceOrdinalAttributeForNode(node, device.id)); + + // Place the execute Op on the TPU_SYSTEM device so it can access the cache of + // compiled protos in the resource manager. + device.type = target_device_type; + device.id = 0; + + node->set_assigned_device_name(DeviceNameUtils::ParsedNameToString(device)); + return Status::OK(); +} + +// Iterate over the nodes in the original graph and find all the TPUReplicate +// nodes, and all the nodes that are part of outside_compilation clusters. +Status FindTaggedNodes( + Graph* graph, std::vector<Node*>* replicate_nodes, + std::map<string, DistributedTPURewritePass::OutsideCompilationNodeMap>* + outside_compilation_nodes, + std::map<string, std::vector<Node*>>* head_tail_outside_compilation_nodes) { + for (Node* node : graph->op_nodes()) { + if (node->type_string() == "_TPUReplicate") { + replicate_nodes->push_back(node); + const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr); + if (cluster_attr == nullptr) { + return errors::Internal("TPUReplicate node ", node->name(), " has no ", + kTPUReplicateAttr, " attr."); + } else { + const string& cluster = cluster_attr->s(); + if (cluster.empty()) { + return errors::Internal("Attr ", kTPUReplicateAttr, " on node ", + node->name(), " has no string value."); + } + if (outside_compilation_nodes->find(cluster) != + outside_compilation_nodes->end()) { + return errors::Internal( + "TPUReplicate node ", node->name(), " has ", kTPUReplicateAttr, + " attr value '", cluster, + "' which is a duplicate of another TPUReplicate node in the " + "graph."); + } + (*outside_compilation_nodes)[cluster] = + DistributedTPURewritePass::OutsideCompilationNodeMap(); + (*head_tail_outside_compilation_nodes)[cluster] = std::vector<Node*>(); + } + } + } + for (Node* node : graph->op_nodes()) { + if (node->type_string() != "_TPUReplicate") { + const AttrValue* cluster_attr = node->attrs().Find(kTPUReplicateAttr); + const AttrValue* outside_compilation_attr = + node->attrs().Find(kOutsideCompilationAttr); + if (cluster_attr == nullptr) { + if (outside_compilation_attr != nullptr) { + return errors::Internal("Node ", node->name(), " has ", + kOutsideCompilationAttr, " attr but no ", + kTPUReplicateAttr, " attr."); + } + } else { + const string& cluster = cluster_attr->s(); + if (cluster.empty()) { + return errors::Internal("Attr ", kTPUReplicateAttr, " on node ", + node->name(), " has no string value."); + } + const auto iter = outside_compilation_nodes->find(cluster); + if (iter == outside_compilation_nodes->end()) { + return errors::Internal( + "Attr ", kTPUReplicateAttr, " on node ", node->name(), + " does not correspond to a TPUReplicate node."); + } + if (outside_compilation_attr == nullptr) { + return errors::Internal("Node ", node->name(), " has ", + kTPUReplicateAttr, " attr but no ", + kOutsideCompilationAttr, " attr."); + } + const string& oc_cluster = outside_compilation_attr->s(); + if (oc_cluster.empty()) { + return errors::Internal("Attr ", kOutsideCompilationAttr, " on node ", + 
node->name(), " has no string value."); + } + + // Outside compilation cluster at head and tail of TPU computation has + // already been moved to host and is already replicated. As so, do not + // replicate outside compilation nodes with replica id attribute. + int replica_id; + if (TryGetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)) { + const AttrValue* head_attr = + node->attrs().Find("_xla_only_arg_or_oc_input"); + const AttrValue* tail_attr = + node->attrs().Find("_xla_only_ret_or_oc_output"); + if (((head_attr != nullptr) && (head_attr->b())) || + ((tail_attr != nullptr) && (tail_attr->b()))) { + // This is safe as this has the same keys as + // outside_compilation_nodes which we already know has this key. + (*head_tail_outside_compilation_nodes)[cluster].push_back(node); + } + continue; + } + iter->second[oc_cluster].push_back(node); + } + } + } + return Status::OK(); +} + +// Helper class to spread TPU computation arguments and return values +// across cores. +// If all shapes are fully defined, balance by their size. +// If some of them are not fully defined, the undefined shapes size will +// be estimated with the average size of the fully defined ones. +// If none are defined, fall back to round-robin. +class TensorDevicePlacer { + public: + // Creates a TensorDevicePlacer object to distribute arguments or + // return values to a set of num_devices devices, where the types and + // the inferred shapes of the inputs (arguments or return values) are + // passed in types and shapes. + TensorDevicePlacer(int64 num_devices, const DataTypeVector& types, + const std::vector<InferredShape>& shapes) + : index_nodes_(num_devices), sizes_(types.size()) { + int64 total_size = 0; + int64 num_defined = 0; + for (int64 i = 0; i < types.size(); ++i) { + sizes_[i] = GetInferredShapeSize(shapes[i], types[i]); + if (sizes_[i] >= 0) { + total_size += sizes_[i]; + ++num_defined; + } + } + // If a shape is undefined, select a size for it which is the average + // of the defined shapes. If no shapes are defined, assign 1 so that we + // get round-robin behavior. + int64 undefined_shape_size = + (num_defined > 0) ? total_size / num_defined : 1; + for (int64 i = 0; i < sizes_.size(); ++i) { + if (sizes_[i] < 0) { + sizes_[i] = undefined_shape_size; + } + } + + for (int64 i = 0; i < num_devices; ++i) { + heap_.Push(&index_nodes_[i]); + } + } + + // Reports that the argument/return-value at index has been assigned + // by the user to a given device. + void ReportDeviceAssigned(int64 device, int64 index) { + DeviceNode* node = &index_nodes_.at(device); + node->size += sizes_.at(index); + heap_.Adjust(node); + } + + // Retrieves the device at which the argument/return-value at index + // should be assigned to. + int64 RetrieveAssignment(int64 index) { + DeviceNode* node = heap_.top(); + int64 device = node - index_nodes_.data(); + node->size += sizes_.at(index); + heap_.Adjust(node); + return device; + } + + private: + struct DeviceNode { + struct Compare { + // Compare functor to implement a min heap using the ::gtl::IntrusiveHeap + // infrastructure. + bool operator()(const DeviceNode* lhs, const DeviceNode* rhs) const { + return lhs->size < rhs->size; + } + }; + + IntrusiveHeapLink heap; + int64 size = 0; + }; + + static int64 GetInferredShapeSize(const InferredShape& ishape, + DataType dtype) { + return ishape.shape.IsFullyDefined() + ? 
ishape.shape.num_elements() * DataTypeSize(dtype) + : -1; + } + + std::vector<DeviceNode> index_nodes_; + IntrusiveHeap<DeviceNode, typename DeviceNode::Compare> heap_; + std::vector<int64> sizes_; +}; + +Status ValidateCoreNumber(int64 core, int64 num_cores_per_replica) { + if (core < 0 || core >= num_cores_per_replica) { + return tensorflow::errors::InvalidArgument("Invalid core ID: ", core, + ". The valid core IDs are [0..", + num_cores_per_replica, ")"); + } + return Status::OK(); +} + +Status FindHostComputeKeyPlaceholderNodes( + const Graph* graph, const std::vector<Node*>& replicate_nodes, + std::unordered_map<string, Node*>* host_compute_key_placeholder_map) { + host_compute_key_placeholder_map->clear(); + for (const auto node : replicate_nodes) { + (*host_compute_key_placeholder_map)[node->name()] = nullptr; + } + + for (Node* node : graph->op_nodes()) { + if (node->type_string() == "Placeholder" && + str_util::EndsWith(node->name(), "_key_placeholder")) { + const AttrValue* call_node_attr = + node->attrs().Find("_host_compute_call_node"); + if (call_node_attr != nullptr) { + auto iter = host_compute_key_placeholder_map->find(call_node_attr->s()); + if (iter == host_compute_key_placeholder_map->end()) { + return errors::InvalidArgument( + "Node ", node->name(), " has _host_compute_call_node attribute '", + call_node_attr->s(), "' that doesn't correspond to a call node"); + } + if (iter->second != nullptr) { + return errors::InvalidArgument( + "Key placeholder node ", iter->second->name(), " for call node ", + call_node_attr->s(), " previously found as ", + iter->second->name()); + } + iter->second = node; + } + } + } + + return Status::OK(); +} + +Status ReplaceCompilationResultNodeWithIdentity(Graph* graph, Node** node) { + Node* old_node = *node; + // We want to replace the node with an identity node with the same name. + const string& node_name = old_node->name(); + + // Create identity node. + TF_ASSIGN_OR_RETURN( + Node * id_node, + BuildIdentityNode(graph, node_name, DT_STRING, + /*input=*/nullptr, /*requested_device=*/"")); + + // No incoming edges are copied as a new one will be added from compile node + // to id_node. + + // Copy outgoing edges to the id node. 
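+  // The out edges are copied into a vector first because the loop below
+  // removes edges while iterating, which would otherwise invalidate iteration
+  // over old_node->out_edges().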
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + Node* dst = edge->dst(); + int src_output = edge->src_output(); + int dst_input = edge->dst_input(); + + if (src_output == Graph::kControlSlot) { + graph->AddControlEdge(id_node, dst); + } else { + graph->AddEdge(id_node, src_output, dst, dst_input); + } + graph->RemoveEdge(edge); + } + graph->RemoveNode(old_node); + + *node = id_node; + return Status::OK(); +} + +Status FillPaddingMap( + const Node& replicate_node, + protobuf::RepeatedPtrField<tpu::PaddingMap>* padding_maps) { + std::vector<string> padding_map_strs; + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "padding_map", &padding_map_strs)); + padding_maps->Reserve(padding_map_strs.size()); + for (const string& padding_map_str : padding_map_strs) { + tpu::PaddingMap* padding_map = padding_maps->Add(); + if (!padding_map->ParseFromString(padding_map_str)) { + return errors::InvalidArgument( + "Malformed padding_map serialized string: ", padding_map_str); + } + } + return Status::OK(); +} + +Status GetStepMarkerLocation(const Node& replicate_node, + xla::DebugOptions::StepMarkerLocation* location) { + string step_marker_location_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "step_marker_location", + &step_marker_location_attr)); + if (step_marker_location_attr.empty()) { + *location = xla::DebugOptions::STEP_MARK_AT_ENTRY; + } else { + if (!xla::DebugOptions::StepMarkerLocation_Parse(step_marker_location_attr, + location)) { + return errors::InvalidArgument("Malformed step_marker_location: ", + step_marker_location_attr); + } + } + return Status::OK(); +} + +// Extracts a map of dimension and number of splits for tiled input from xla +// sharding attribute. +Status GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) { + for (int dim_index = 0; + dim_index < sharding.tile_assignment_dimensions_size(); dim_index++) { + if (sharding.tile_assignment_dimensions(dim_index) > 1) { + split_dimension_map->emplace( + dim_index, sharding.tile_assignment_dimensions(dim_index)); + } + } + + if (split_dimension_map->empty()) { + return errors::InvalidArgument("Arg has unnecessary tiled sharding: ", + sharding.DebugString()); + } + return Status::OK(); +} + +// Updates contents of the function with `function_name` in function library +// definition `flib_def` to `new_graph`. This is required when graph +// transformation happens inside a function call body. +Status UpdateFunctionLibDefinition(const Graph& new_graph, + const std::string& function_name, + FunctionLibraryDefinition* flib_def) { + FunctionDef graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(new_graph, function_name, &graph_fdef)); + TF_RETURN_IF_ERROR(flib_def->ReplaceFunction(function_name, graph_fdef)); + return Status::OK(); +} + +struct NodeOut { + Node* node; + int index; +}; + +struct ShardedInputIndex { + int replica_id; + int argument_index; + + bool operator<(const ShardedInputIndex& rhs) const { + return std::tie(replica_id, argument_index) < + std::tie(rhs.replica_id, rhs.argument_index); + } +}; + +struct ShardedInputInfo { + // Split node that would be connected to tiled input Node. + Node* split_node; + // List of splits nodes and output index of the split node from which sharded + // input will be connected to the TPUExecute node. The inputs are ordered by + // logical core ids. 
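+  // For example, with a {2, 2} tile assignment this holds four entries, one
+  // per logical core, in row-major order over the tile grid.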
+ std::vector<NodeOut> sharded_inputs; +}; + +// Adds split node and split dimension node to graph for sharding tiled inputs. +// |graph| owns the returned Node* instance. +xla::StatusOr<Node*> CreateSplitNode(int num_splits, int dim, + int orig_src_output, DataType dtype, + absl::string_view name_prefix, + Node* control_predecessor, Node* orig_src, + Graph* graph) { + const std::string input_assigned_device = orig_src->assigned_device_name(); + + // Add a split dimension node. + NodeDef split_dim_def; + split_dim_def.set_name( + graph->NewName(absl::StrCat(name_prefix, "/split_dim"))); + split_dim_def.set_op("Const"); + split_dim_def.set_device(input_assigned_device); + AddNodeAttr("dtype", DT_INT32, &split_dim_def); + TensorProto tensor_proto; + tensor_proto.set_dtype(DT_INT32); + tensor_proto.add_int_val(dim); + TensorShape shape({}); + shape.AsProto(tensor_proto.mutable_tensor_shape()); + AddNodeAttr("value", tensor_proto, &split_dim_def); + Status s; + Node* split_dim_node = graph->AddNode(split_dim_def, &s); + TF_RETURN_IF_ERROR(s); + // Add a split node. + NodeDef split_def; + split_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/split"))); + split_def.set_op("Split"); + split_def.set_device(input_assigned_device); + AddNodeAttr("num_split", num_splits, &split_def); + AddNodeAttr("T", dtype, &split_def); + split_def.add_input(absl::StrCat(split_dim_node->name(), ":0")); + split_def.add_input(absl::StrCat(orig_src->name(), ":", orig_src_output)); + Node* split_node = graph->AddNode(split_def, &s); + TF_RETURN_IF_ERROR(s); + + graph->AddEdge(split_dim_node, 0, split_node, 0); + graph->AddEdge(orig_src, orig_src_output, split_node, 1); + + // Add a control dependency from `control_predecessor` to newly created + // constant node. This ensures that newly added split/split dim + // nodes are placed inside correct while loop frames when TPUExecute + // node is inside a host training loop. + graph->AddControlEdge(control_predecessor, split_dim_node); + + return split_node; +} + +// Creates a set of splits nodes that shards tiled input node in graph. +xla::StatusOr<ShardedInputInfo> CreateOrGetSplitNodesForInputSharding( + const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, + int replica_id, int orig_src_output, Node* orig_src, + Node* control_predecessor, Graph* graph, + std::map<ShardedInputIndex, ShardedInputInfo>* + arg_index_to_sharded_input_map) { + ShardedInputIndex input_index{replica_id, orig_arg_num}; + auto iter = arg_index_to_sharded_input_map->find(input_index); + if (iter != arg_index_to_sharded_input_map->end()) { + return iter->second; + } + // Maps input dimension and number of splits with which the + // dimension sharded. + std::map<int, int> split_dimension_map; + TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding( + sharding, &split_dimension_map)); + TF_RET_CHECK(!split_dimension_map.empty()) + << "Unnecessary sharding attribute found."; + + // For v1 while loop, nodes inside the loop body must either + // 1) Have data edges from while loop input node. + // or + // 2) Have direct control dependency from while loop input control + // node. + // + // As so, if we are adding Split node inside, while loop body, + // we must manually add a control dependency to a node inside + // a while loop (i.e. `control_predecessor`) to constant nodes + // without data in-edges to make sure that added split nodes + // have correct frame name. Else, placer will complain when + // `BuildControlFlow()` is invoked. 
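+  //
+  // As an illustration: for a {2, 2} tiled sharding, the code below first
+  // builds one Split along dimension 0 and then splits each of its two
+  // outputs along dimension 1, so the queue ends up holding the dimension-1
+  // split nodes whose outputs feed the four shards in row-major order.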
+ + auto sharding_it = split_dimension_map.begin(); + std::queue<Node*> split_nodes_for_dimension; + int split_dimension = sharding_it->first; + int num_split = sharding_it->second; + + // Creates a tree of split nodes for sharding tiled inputs. Splits nodes + // are created such that input data is sharded in row major order. + // Split nodes at ith depth from the original input node represent nodes + // that split the input data at ith dimension. + TF_ASSIGN_OR_RETURN( + Node * root_split_node, + CreateSplitNode(num_split, split_dimension, orig_src_output, dtype, + absl::StrCat("sharded_input/replica_", replica_id, + "_dim_", split_dimension), + control_predecessor, orig_src, graph)); + sharding_it++; + + split_nodes_for_dimension.emplace(root_split_node); + + while (sharding_it != split_dimension_map.end()) { + split_dimension = sharding_it->first; + num_split = sharding_it->second; + int num_split_nodes_in_dimension = split_nodes_for_dimension.size(); + for (int i = 0; i < num_split_nodes_in_dimension; ++i) { + Node* input_split_node = split_nodes_for_dimension.front(); + split_nodes_for_dimension.pop(); + for (int src_output_index = 0; + src_output_index < input_split_node->num_outputs(); + ++src_output_index) { + TF_ASSIGN_OR_RETURN( + Node * split_node, + CreateSplitNode(num_split, split_dimension, src_output_index, dtype, + absl::StrCat("sharded_input/replica_", replica_id, + "_dim_", split_dimension), + control_predecessor, input_split_node, graph)); + split_nodes_for_dimension.emplace(split_node); + } + } + sharding_it++; + } + + // `split_nodes_for_dimension` now includes final split nodes + // from which sharded data will be fed into TPUExcute nodes -- sorted by + // row major order. + std::vector<NodeOut> sharded_inputs_list; + sharded_inputs_list.reserve(split_nodes_for_dimension.size()); + while (!split_nodes_for_dimension.empty()) { + Node* split_node = split_nodes_for_dimension.front(); + split_nodes_for_dimension.pop(); + int num_splits; + TF_RETURN_IF_ERROR( + GetNodeAttr(split_node->def(), "num_split", &num_splits)); + for (int out_index = 0; out_index < num_splits; ++out_index) { + sharded_inputs_list.emplace_back(NodeOut{split_node, out_index}); + } + } + + ShardedInputInfo sharded_input_info{root_split_node, + std::move(sharded_inputs_list)}; + (*arg_index_to_sharded_input_map)[input_index] = sharded_input_info; + return sharded_input_info; +} + +// Creates a concat node to be used for aggregating sharded retvals across +// logical cores. +xla::StatusOr<Node*> CreateConcatNode(int dim, int num_splits, DataType dtype, + absl::string_view name_prefix, + const std::vector<NodeOut>& inputs, + Graph* graph, absl::string_view device) { + // Add a Concat dim node. + NodeDef concat_dim_def; + concat_dim_def.set_name( + graph->NewName(absl::StrCat(name_prefix, "/concat_dim"))); + concat_dim_def.set_op("Const"); + AddNodeAttr("dtype", DT_INT32, &concat_dim_def); + concat_dim_def.set_device(std::string(device)); + TensorProto tensor_proto; + tensor_proto.set_dtype(DT_INT32); + tensor_proto.add_int_val(dim); + TensorShape shape({}); + shape.AsProto(tensor_proto.mutable_tensor_shape()); + AddNodeAttr("value", tensor_proto, &concat_dim_def); + Status s; + Node* concat_dim_node = graph->AddNode(concat_dim_def, &s); + TF_RETURN_IF_ERROR(s); + + // Add a Concat node. 
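+  // The "Concat" op takes the concat dimension as its first input (unlike
+  // "ConcatV2", where the axis comes last), which is why the dim node is
+  // added as input 0 and the data inputs start at index 1 below.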
+ NodeDef concat_def; + concat_def.set_name(graph->NewName(absl::StrCat(name_prefix, "/concat"))); + concat_def.set_op("Concat"); + AddNodeAttr("N", num_splits, &concat_def); + AddNodeAttr("T", dtype, &concat_def); + concat_def.add_input(absl::StrCat(concat_dim_node->name(), ":0")); + concat_def.set_device(std::string(device)); + for (const auto& i : inputs) { + concat_def.add_input(absl::StrCat(i.node->name(), ":", i.index)); + } + Node* concat_node = graph->AddNode(concat_def, &s); + TF_RETURN_IF_ERROR(s); + + graph->AddEdge(concat_dim_node, 0, concat_node, 0); + + // 0th input to concat node is a concat dim node. So we start from 1st input + // and add all input edges. + int dst_input = 1; + for (const auto& i : inputs) { + graph->AddEdge(i.node, i.index, concat_node, dst_input); + ++dst_input; + } + return concat_node; +} + +// Creates a set of Concat nodes that aggregates sharded outputs from TPUExecute +// nodes into a single output. Sharded outputs are concatenated along row major +// order. That is, tiled output along 0th dimension will be concatenated last. +xla::StatusOr<Node*> CreateConcatNodesForRetval( + const xla::OpSharding& sharding, DataType dtype, int replica_id, + const std::vector<NodeOut>& orig_inputs, Graph* graph, + absl::string_view device) { + std::map<int, int> split_dimension_map; + TF_RETURN_IF_ERROR(GetDimensionIndicesAndNumSplitsFromSharding( + sharding, &split_dimension_map)); + + std::vector<NodeOut> inputs_to_sharded_retval = orig_inputs; + + for (auto it = split_dimension_map.rbegin(); it != split_dimension_map.rend(); + it++) { + auto dim = it->first; + auto num_splits = it->second; + + int num_concat_nodes = inputs_to_sharded_retval.size() / num_splits; + int input_index_to_concat_node = 0; + + std::vector<NodeOut> new_concat_nodes; + for (int i = 0; i < num_concat_nodes; ++i) { + auto concat_input_it = + inputs_to_sharded_retval.begin() + input_index_to_concat_node; + std::vector<NodeOut> inputs(concat_input_it, + concat_input_it + num_splits); + input_index_to_concat_node += num_splits; + + TF_ASSIGN_OR_RETURN( + Node * concat_node, + CreateConcatNode( + dim, num_splits, dtype, + absl::StrCat("sharded_output/replica_", replica_id, "_dim_", dim), + inputs, graph, device)); + new_concat_nodes.emplace_back(NodeOut{concat_node, 0}); + } + inputs_to_sharded_retval = new_concat_nodes; + } + + TF_RET_CHECK(inputs_to_sharded_retval.size() == 1); + return inputs_to_sharded_retval.at(0).node; +} + +absl::optional<int> GetCoreIndexInSharding(const xla::OpSharding& sharding, + int64 core) { + absl::optional<int> output_index; + for (int i = 0; i < sharding.tile_assignment_devices_size(); i++) { + int64 assigned_core = sharding.tile_assignment_devices(i); + if (assigned_core == core) { + output_index = i; + break; + } + } + return output_index; +} + +// Set the padding ops the same devices as the original inputs. If the original +// inputs are on TPUs, the padding ops will be placed on TPUs and XLA on demand +// mode will be triggered, so we don't need to copy the data back to the host +// to do the padding. 
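+// Roughly: for each "Pad" node tagged with kPostDeviceRewriteAttr, the loop
+// below copies the requested/assigned device of its unpadded input onto the
+// other tagged consumers of that input, so the padding is computed where the
+// data already lives.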
+Status SetPaddingNodesDevices(Graph* graph) { + for (Node* n : graph->op_nodes()) { + bool tpu_padding_attr; + if (n->type_string() == "Pad" && + GetNodeAttr(n->attrs(), kPostDeviceRewriteAttr, &tpu_padding_attr) + .ok()) { + Node* unpadded_input; + TF_RETURN_IF_ERROR(n->input_node(0, &unpadded_input)); + + const string& requested_device = unpadded_input->requested_device(); + const string& assigned_device = unpadded_input->assigned_device_name(); + if (!requested_device.empty() || !assigned_device.empty()) { + // The output nodes of the original unpadded inputs include the padded + // inputs and real shapes of inputs, we assign those to the same device + // as the original inputs. + for (Node* out : unpadded_input->out_nodes()) { + if (GetNodeAttr(out->attrs(), kPostDeviceRewriteAttr, + &tpu_padding_attr) + .ok()) { + out->set_requested_device(requested_device); + out->set_assigned_device_name(assigned_device); + } + } + // There might be a tf.shape node added before TPUCompileOp, we need to + // set its device as well. + for (Node* out : n->out_nodes()) { + if (n->type_string() == "Shape") { + out->set_requested_device(requested_device); + out->set_assigned_device_name(assigned_device); + } + } + } + } + } + return Status::OK(); +} + +const string& AssignedOrRequestedDevice(const Node* node) { + if (!node->assigned_device_name().empty()) { + return node->assigned_device_name(); + } + return node->requested_device(); +} + +bool IsTpuDevice(const string& device_string) { + DeviceNameUtils::ParsedName device; + return DeviceNameUtils::ParseFullName(device_string, &device) && + device.type == DEVICE_TPU_NODE; +} + +// Returns a set of device ops can be placed on TPU. There is no strict rule of +// thumb to decide which ops should be in the list, but empirically they are +// mostly dummy ops like Identity-like ops or control flow related ops. However +// people can add also add other ops like Pad to allow data stay on TPU. +const absl::flat_hash_set<std::string>& PlaceOnTPUOpList() { + static const auto place_on_tpu_ops = new absl::flat_hash_set<std::string>( + {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge", + "NextIteration", "Shape"}); + return *place_on_tpu_ops; +} + +// If an op satisfies the following conditions, it will be placed on the same +// TPU device as its inputs: +// (1) The op can be placed on TPU (in the PlaceOnTPUOpList) +// (2) The op itself has no requested or assigned devices. +// (3) All the data inputs of this op are placed on the same device on TPUs. +// There are exceptions like the NextIterations input of Switch node can +// be placed on CPU as it is just a boolean. +// +// Returns true if the node device has been changed, otherwise returns false. +bool PlaceOpsOnTPU(Node* node) { + if (!AssignedOrRequestedDevice(node).empty() || + !PlaceOnTPUOpList().contains(node->type_string())) { + return false; + } + string src_tpu_device = ""; + Node* src_node; + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + Node* src = e->src(); + const string& src_device = AssignedOrRequestedDevice(src); + + // Make exceptions that we don't force the some inputs to place on TPUs. 
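+    // e.g. the LoopCond input of a Switch is just a boolean and may stay on
+    // the host, so it is skipped when checking input devices here.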
+ if (node->IsSwitch() && src->IsLoopCond()) { + continue; + } + + if (!IsTpuDevice(src_device) || + (!src_tpu_device.empty() && src_device != src_tpu_device)) { + return false; + } + if (src_tpu_device.empty()) { + src_tpu_device = src_device; + src_node = src; + } + } + node->set_assigned_device_name(src_node->assigned_device_name()); + node->set_requested_device(src_node->requested_device()); + return true; +} + +// Validate sharding configuration derived from XlaSharding attribute. +// Infer the core id from the OpSharding, if necessary. +Status ParseAndValidateSharding(const xla::OpSharding& sharding, + const int num_cores_per_replica, + int64* inferred_core_id, + absl::optional<xla::OpSharding>* result) { + if (sharding.type() == xla::OpSharding::MAXIMAL) { + int64 core_annotation = sharding.tile_assignment_devices(0); + TF_RETURN_IF_ERROR( + ValidateCoreNumber(core_annotation, num_cores_per_replica)); + if (*inferred_core_id == -1 || *inferred_core_id > core_annotation) { + *inferred_core_id = core_annotation; + result->emplace(sharding); + } + } else { + if (sharding.type() == xla::OpSharding::OTHER) { + for (int64 core : sharding.tile_assignment_devices()) { + TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica)); + } + } + + if (!result->has_value()) { + *result = sharding; + } else { + std::string result_value_serialized; + std::string sharding_serialized; + SerializeToStringDeterministic(result->value(), &result_value_serialized); + SerializeToStringDeterministic(sharding, &sharding_serialized); + + if (result_value_serialized != sharding_serialized) { + // We see different shardings, assign to core 0. + result->emplace(xla::sharding_builder::AssignDevice(0)); + } + } + } + return Status::OK(); +} + +// As XlaSharding node may be followed by Cast op or an Identity op, +// recursively walk the graph and aggregate nodes connectd to +// |input_node| or Cast/Identity op following the |input_node|. +void FindNodesMaybeContainingShardingInfo(const Node& input_node, + std::vector<const Node*>* nodes) { + if (input_node.IsIdentity() || input_node.type_string() == "Cast") { + for (const Node* connected_node : input_node.out_nodes()) + FindNodesMaybeContainingShardingInfo(*connected_node, nodes); + } + nodes->emplace_back(&input_node); +} + +// Parse sharding configuration from |node| or it's adjacent nodes. +// XlaSharding configuration may be derived from +// a) Connected Identity op node. +// b) Connected Cast op node. +xla::StatusOr<absl::optional<xla::OpSharding>> +ParseInputShardingFromAdjacentNode(const int num_cores_per_replica, + const Node& node) { + // If |node| has `device` attribute or is a XlaSharding op, + // return the parsed OpSharding. + TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, + ParseShardingFromDevice(node, num_cores_per_replica)); + if (sharding.has_value()) return sharding; + + // XlaShardingOp may be followed by an identity or followed by identity + // and a Cast op. 
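+  // FindNodesMaybeContainingShardingInfo (defined above) gathers |node| plus
+  // everything reachable from it through chains of Identity/Cast ops; the
+  // loop below then keeps only the XlaSharding ops among those candidates.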
+ std::vector<const Node*> potential_nodes_with_input_sharding; + FindNodesMaybeContainingShardingInfo(node, + &potential_nodes_with_input_sharding); + for (const Node* maybe_node_with_sharding_info : + potential_nodes_with_input_sharding) { + if (maybe_node_with_sharding_info->type_string() != "XlaSharding") continue; + + TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding_config, + ParseShardingFromDevice(*maybe_node_with_sharding_info, + num_cores_per_replica)); + if (sharding_config.has_value()) return sharding_config; + } + return sharding; +} + +// Walk the graph from an argument node to find OpSharding configuration +// from its neighbor nodes. Sharding configuration may be inferred from +// 1) Parsing XlaSharding attribute from neighboring node. +// 2) If argument node is a resource, then by parsing adjacent nodes +// of the connected ReadVariable op. +Status ParseAndValidateShardingFromNeighbors( + const int num_cores_per_replica, const std::string& arg_node_name, + const Node& neighbor_node, int64* inferred_core_id, bool* is_fast_mem, + absl::optional<xla::OpSharding>* result) { + if (neighbor_node.attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) { + *is_fast_mem = true; + VLOG(2) << "place " << neighbor_node.name() << " on fast memory because " + << arg_node_name << " has " << TPU_FAST_MEM_ATTR << " attribute"; + } + + // XlaSharding information may be encoded on node directly connected to the + // argument node. + TF_ASSIGN_OR_RETURN( + absl::optional<xla::OpSharding> sharding, + ParseInputShardingFromAdjacentNode(num_cores_per_replica, neighbor_node)); + if (sharding.has_value()) { + TF_RETURN_IF_ERROR(ParseAndValidateSharding( + *sharding, num_cores_per_replica, inferred_core_id, result)); + return Status::OK(); + } + + // When we use variable in TPU computation, we always have a + // XlaSharding op followed by a ReadVariableOp. As so, correctly parse + // the users of ReadVariableOp for potential sharding configuration. + if (neighbor_node.type_string() == "ReadVariableOp") { + for (const Edge* e : neighbor_node.out_edges()) { + if (e->IsControlEdge()) continue; + + if (e->dst()->attrs().Find(TPU_FAST_MEM_ATTR) != nullptr) { + *is_fast_mem = true; + VLOG(2) << "place " << arg_node_name << " on fast memory because " + << e->dst()->name() << TPU_FAST_MEM_ATTR << " attribute"; + } + + TF_ASSIGN_OR_RETURN( + absl::optional<xla::OpSharding> sharding, + ParseInputShardingFromAdjacentNode(num_cores_per_replica, *e->dst())); + if (sharding.has_value()) { + TF_RETURN_IF_ERROR(ParseAndValidateSharding( + *sharding, num_cores_per_replica, inferred_core_id, result)); + return Status::OK(); + } + } + } + return Status::OK(); +} + +} // namespace + +// Inputs: +// replication_spec_string: the device to which the TPUReplicate node was +// assigned. +// device_set: the set of TF devices. +// Outputs: +// tpu_compilation_device: the name of the TPU compilation device. +// num_tpus_per_task: the number of TPUs in each task. Verifies that all tasks +// have the same number of TPU devices. +// tpu_devices: the TPU devices, indexed by [task][device]. +static Status GetTPUDeviceNames( + const string& replication_spec_string, const DeviceSet& device_set, + string* tpu_compilation_device, int* num_tpus_per_task, + std::vector<std::vector<Device*>>* tpu_devices) { + // TODO(b/110910013) GetSystemDevice parses the spec and returns the name of + // the tpu_system device, which we replace by the cpu device. 
We do this + // replacement because we want to place the TPUCompileOp (and the compile + // assert op) explicitly on cpu devices on the same job as the tpu_system + // device. + DeviceNameUtils::ParsedName replication_spec; + Device* replication_device; + TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetSystemDevice( + replication_spec_string, device_set, &replication_spec, + &replication_device)); + *tpu_compilation_device = + str_util::StringReplace(replication_device->name(), DEVICE_TPU_SYSTEM, + DEVICE_CPU, /*replace_all=*/true); + + // Finds the set of TPU devices attached to the tasks in the job. + TF_RETURN_IF_ERROR(DistributedTPURewriteHelpers::GetTPUDevices( + replication_spec, device_set, num_tpus_per_task, tpu_devices)); + + return Status::OK(); +} + +// Parses the topology attribute of TPUReplicate, and populates *topology with +// a physical mesh coordinate to (task, device) mapping. +static Status ParseTopologyAttr(const string& topology_attr, + const tpu::TpuTopologyExternal& tpu_topology, + int num_tasks, int num_tpus_per_task, + xla::Array4D<std::pair<int, int>>* topology) { + static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4"); + tpu::TopologyProto proto; + proto.ParseFromString(topology_attr); + if (proto.mesh_shape_size() != kTPUTopologyRank) { + return errors::InvalidArgument("TPU topology must be rank ", + kTPUTopologyRank); + } + if (proto.num_tasks() != num_tasks) { + return errors::InvalidArgument("Mismatched number of TPU tasks"); + } + if (proto.num_tpu_devices_per_task() != num_tpus_per_task) { + return errors::InvalidArgument("Mismatched number of TPUs per task (", + proto.num_tpu_devices_per_task(), + " != ", num_tpus_per_task, ")."); + } + if (proto.device_coordinates_size() != + num_tasks * num_tpus_per_task * kTPUTopologyRank) { + return errors::InvalidArgument( + "device coordinates should be ", num_tasks, "x", num_tpus_per_task, "x", + kTPUTopologyRank, "; got ", proto.device_coordinates_size()); + } + + int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore); + *topology = xla::Array4D<std::pair<int, int>>( + tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y, + tpu_topology.chip_bounds().z, devices_per_chip, {-1, -1}); + int pos = 0; + for (int task = 0; task < num_tasks; ++task) { + for (int device = 0; device < num_tpus_per_task; ++device) { + int32 x = proto.device_coordinates(pos++); + int32 y = proto.device_coordinates(pos++); + int32 z = proto.device_coordinates(pos++); + int32 core = proto.device_coordinates(pos++); + + if (!tpu_topology.HasChip(x, y, z) || core < 0 || + core >= devices_per_chip) { + return errors::InvalidArgument( + "Mesh coordinates (", x, ",", y, ",", z, ",", core, + ") are not valid for the current TPU topology"); + } + if ((*topology)(x, y, z, core).first != -1) { + return errors::InvalidArgument("Duplicate coordinates (", x, ",", y, + ",", z, ",", core, ") in TPU topology"); + } + (*topology)(x, y, z, core) = {task, device}; + } + } + return Status::OK(); +} + +// Parses the value of the device_assignment attribute to TPUReplicate. +// Populates *device_assignment; *device_assignment must be a 2D array with +// shape (num_replicas, num_cores_per_replica). 
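+// The attribute is a flat list with kTPUTopologyRank (= 4) integers, i.e. an
+// (x, y, z, core) tuple, per (replica, logical core) pair, laid out
+// replica-major; the loops below consume it in exactly that order.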
+static Status ParseDeviceAssignmentAttr( + absl::Span<const int> device_assignment_attr, + const tpu::TpuTopologyExternal& tpu_topology, int num_replicas, + int num_cores_per_replica, + xla::Array2D<tpu::TpuCoreLocationExternal>* device_assignment) { + static_assert(4 == kTPUTopologyRank, "Assumes the topology rank is 4"); + + const int64 device_assignment_attr_size = + num_replicas * num_cores_per_replica * kTPUTopologyRank; + if (device_assignment_attr.size() != device_assignment_attr_size) { + return errors::InvalidArgument( + "Length of device_assignment attribute must be equal to num_replicas (", + num_replicas, ") * num_cores_per_replica (", num_cores_per_replica, + ") * ", kTPUTopologyRank, " got ", device_assignment_attr.size()); + } + for (int core : device_assignment_attr) { + if (core < 0 || core >= kTPUMaxTopologySize) { + return errors::InvalidArgument( + "Invalid core number in device assignment: ", core); + } + } + + *device_assignment = xla::Array2D<tpu::TpuCoreLocationExternal>( + num_replicas, num_cores_per_replica); + int devices_per_chip = tpu_topology.LogicalDevicesPerChip(kTensorCore); + xla::Array4D<int> replica_assignment( + tpu_topology.chip_bounds().x, tpu_topology.chip_bounds().y, + tpu_topology.chip_bounds().z, devices_per_chip, -1); + int pos = 0; + for (int replica = 0; replica < num_replicas; ++replica) { + for (int logical_core = 0; logical_core < num_cores_per_replica; + ++logical_core) { + int32 x = device_assignment_attr[pos++]; + int32 y = device_assignment_attr[pos++]; + int32 z = device_assignment_attr[pos++]; + int32 core = device_assignment_attr[pos++]; + + if (!tpu_topology.HasChip(x, y, z) || core < 0 || + core >= devices_per_chip) { + return errors::InvalidArgument( + "Mesh coordinates (", x, ",", y, ",", core, + ") are not valid for the current TPU topology"); + } + tpu::TpuCoreLocationExternal core_location = + tpu_topology.Core(x, y, z, kTensorCore, core); + + if (replica_assignment(x, y, z, core) != -1) { + return errors::InvalidArgument("Duplicate coordinates (", x, ",", y, + ",", z, ",", core, + ") in TPU device assignment"); + } + replica_assignment(x, y, z, core) = replica; + (*device_assignment)(replica, logical_core) = core_location; + } + } + return Status::OK(); +} + +// Builds TensorFlow device assignments for the special case of a single core +// computation that is replicated to every core in the mesh. +// LINT.IfChange +static Status BuildFullMeshDeviceAssignment( + int num_replicas, const std::vector<std::vector<Device*>>& tpu_devices, + int num_tasks, int num_tpus_per_task, + std::vector<std::vector<string>>* tf_device_assignment) { + // Assign TensorFlow devices to replicas arbitrarily. + for (int i = 0; i < num_replicas; ++i) { + int task = i / num_tpus_per_task; + int device = i % num_tpus_per_task; + TF_RET_CHECK(task >= 0 && task < num_tasks); + TF_RET_CHECK(device >= 0 && device < num_tpus_per_task); + + // We don't actually know which TF device corresponds to which physical + // device, but it doesn't matter—they're all identical. + (*tf_device_assignment)[i] = {tpu_devices[task][device]->name()}; + } + return Status::OK(); +} +// LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc) + +// Builds TensorFlow device assignments for a replicated computation and convert +// device_assignment into xla_device_assignment. 
+static Status BuildGeneralDeviceAssignment( + int num_replicas, int num_cores_per_replica, + const std::vector<std::vector<Device*>>& tpu_devices, + const xla::Array2D<tpu::TpuCoreLocationExternal>& device_assignment, + const xla::Array4D<std::pair<int, int>>& topology, + std::vector<std::vector<string>>* tf_device_assignment, + std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) { + // Assign TensorFlow devices to each computation's replicas according to + // device_assignment and 'topology'. + *xla_device_assignment = absl::make_unique<xla::DeviceAssignment>( + num_replicas, num_cores_per_replica); + for (int replica = 0; replica < num_replicas; ++replica) { + for (int computation = 0; computation < num_cores_per_replica; + ++computation) { + const tpu::TpuCoreLocationExternal& core_location = + device_assignment(replica, computation); + + int task; + int device; + std::tie(task, device) = + topology(core_location.chip_coordinates().x, + core_location.chip_coordinates().y, + core_location.chip_coordinates().z, core_location.index()); + + CHECK_LT(computation, num_cores_per_replica); + (**xla_device_assignment)(replica, computation) = core_location.Id(); + + // The communication pattern between replicas will be determined later by + // BuildAllReduceRing. + TF_RET_CHECK(task >= 0 && task < tpu_devices.size()); + TF_RET_CHECK(device >= 0 && device < tpu_devices[task].size()); + (*tf_device_assignment)[replica].push_back( + tpu_devices[task][device]->name()); + } + } + return Status::OK(); +} + +/*static*/ Status DistributedTPURewritePass::BuildDeviceAssignment( + const tpu::TpuTopologyExternal& tpu_topology, int num_tpus_per_task, + const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas, + int num_cores_per_replica, const string& topology_attr, + absl::Span<const int> device_assignment_attr, + std::vector<std::vector<string>>* tf_device_assignment, + std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment) { + const int num_tasks = tpu_devices.size(); + const int num_tpu_devices = num_tasks * num_tpus_per_task; + VLOG(2) << "num_tasks=" << num_tasks + << " num_tpus_per_task=" << num_tpus_per_task; + + // Checks num_replicas is sane first to avoid integer overflow. + if (num_replicas > num_tpu_devices) { +#ifdef PLATFORM_CLOUD_TPU + return errors::InvalidArgument("Requested num_replicas=", num_replicas, + " but there are only ", num_tpu_devices, + " cores in the TPU topology."); +#else + return errors::InvalidArgument("Requested num_replicas=", num_replicas, + " but there are only ", num_tpu_devices, + " cores in the TPU topology."); +#endif + } + if (num_replicas * num_cores_per_replica > num_tpu_devices) { + return errors::InvalidArgument( + "Requested num_replicas=", num_replicas, " with ", + num_cores_per_replica, " cores per replica, but there are only ", + num_tpu_devices, " cores in the TPU topology"); + } + + tf_device_assignment->clear(); + tf_device_assignment->resize(num_replicas); + + // Special case: we allow the user to omit the topology and device assignment + // information in two cases: + // * there is only one replica and one core per replica. In this case, we + // don't need to know topology information because we don't communicate with + // other cores. + // * the number of replicas is equal to the number of cores in the slice. In + // this case, all cores are running the same program so we don't need to + // know which is which. 
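The special cases listed above amount to a small dispatch, sketched standalone below with illustrative names and simplified error handling; the actual handling is the topology_attr.empty() branch that follows.

#include <stdexcept>

enum class AssignmentKind { kSingleDevice, kFullMesh, kFromAttributes };

// With no topology attribute, only two configurations are accepted: a single
// replica (no communication, so no topology needed) and one replica per core
// of a single-core computation (all cores interchangeable).
inline AssignmentKind ChooseAssignmentKind(bool has_topology_attr,
                                           int num_replicas,
                                           int num_cores_per_replica,
                                           int num_tpu_devices) {
  if (has_topology_attr) return AssignmentKind::kFromAttributes;
  if (num_cores_per_replica != 1) {
    throw std::invalid_argument(
        "a topology is required when num_cores_per_replica != 1");
  }
  if (num_replicas == 1) return AssignmentKind::kSingleDevice;
  if (num_replicas == num_tpu_devices) return AssignmentKind::kFullMesh;
  throw std::invalid_argument(
      "without a topology, num_replicas must be 1 or the total core count");
}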
+ if (topology_attr.empty()) { + // LINT.IfChange + if (num_replicas != 1 && num_replicas != num_tpu_devices) { + return errors::InvalidArgument( + "TPUReplicate asked to create ", num_replicas, + " replicas, but the number of cores in the TPU topology is ", + num_tpu_devices, + " and no TPU device assignment was supplied. " + "A TPU device assignment is required if the number of replicas is " + "not 1 or the number of cores in the topology (", + num_tpu_devices, ")"); + } + + if (num_cores_per_replica != 1) { + return errors::InvalidArgument( + "A TPU topology must be provided if num_cores_per_replica != 1"); + } + + if (!device_assignment_attr.empty()) { + return errors::InvalidArgument( + "A TPU topology must be provided if device_assignment_attr is " + "non-empty"); + } + + // If there is only one replica, assign the Tensorflow computation to task 0 + // device 0, and leave the XLA device assignment empty. We don't know which + // core this is in the TPU topology, but it doesn't matter—we don't need to + // communicate with any other cores. + if (num_replicas == 1) { + (*tf_device_assignment)[0] = {tpu_devices[0][0]->name()}; + return Status::OK(); + } + + // Otherwise, num_replicas is equal to the number of cores, and we build a + // device assignment that covers the entire mesh. We do not need to know + // the topology to do so because all cores are identical. + return BuildFullMeshDeviceAssignment(num_replicas, tpu_devices, num_tasks, + num_tpus_per_task, + tf_device_assignment); + // LINT.ThenChange(//tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc) + } + + // Array that maps mesh coordinates to {TF task, TF TPU device #} pairs. + xla::Array4D<std::pair<int, int>> topology; + TF_RETURN_IF_ERROR(ParseTopologyAttr(topology_attr, tpu_topology, num_tasks, + num_tpus_per_task, &topology)); + + // Array that maps logical (replica, core) pairs to physical mesh coordinates. + xla::Array2D<tpu::TpuCoreLocationExternal> device_assignment; + TF_RETURN_IF_ERROR(ParseDeviceAssignmentAttr( + device_assignment_attr, tpu_topology, num_replicas, num_cores_per_replica, + &device_assignment)); + + return BuildGeneralDeviceAssignment( + num_replicas, num_cores_per_replica, tpu_devices, device_assignment, + topology, tf_device_assignment, xla_device_assignment); +} + +Status DistributedTPURewritePass::GetComputationForTPUReplicateOp( + const NameAttrList& function, FunctionLibraryRuntime* flr, + Graph* computation, DataTypeVector* arg_types, + DataTypeVector* retval_types) { + FunctionLibraryRuntime::Handle handle; + + TF_RETURN_IF_ERROR( + flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); + + const FunctionBody* fbody = flr->GetFunctionBody(handle); + + CopyGraph(*fbody->graph, computation); + *arg_types = fbody->arg_types; + *retval_types = fbody->ret_types; + return Status::OK(); +} + +// Grab the InferredShape corresponding to an edge input. 
+static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge, + const InferredShape** info) { + auto it = shape_info.find(edge.src()->name()); + if (it == shape_info.end()) { + return errors::InvalidArgument( + "Input to replicated TPU computation is missing InferredShape: ", + edge.src()->name()); + } + TF_RET_CHECK(it->second.size() > edge.src_output()); + *info = &it->second[edge.src_output()]; + return Status::OK(); +} + +Status DistributedTPURewritePass::GetArgAndRetvalShapes( + const GraphShapeInfo& shape_info, const Node& node, + const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes, + std::vector<InferredShape>* retval_shapes) { + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(node.input_edges(&input_edges)); + + // If any replica's arg shape is unknown, we will mark the computation's arg + // shape as being unknown. If the shapes differ the TpuExecute Op will raise a + // runtime error. + std::vector<bool> any_replica_shape_unknown( + params_info.NumInputsToEachReplica()); + arg_shapes->clear(); + arg_shapes->resize(params_info.NumInputsToEachReplica()); + TF_RET_CHECK(input_edges.size() == params_info.NumInputsFromHost()); + // Determines the shapes of the per-replica arguments and checks that all + // replicas have identical shapes. + int64 edge_pos = 0; + auto check_shape = [&](int input_index) -> Status { + const InferredShape* info; + TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); + ++edge_pos; + + if ((info->handle_type == DT_INVALID && !info->shape.IsFullyDefined()) || + (info->handle_type != DT_INVALID && + !info->handle_shape.IsFullyDefined())) { + any_replica_shape_unknown[input_index] = true; + } + xla::StatusOr<InferredShape> status = + MergeInferredShapes((*arg_shapes)[input_index], *info); + if (!status.ok()) { + return errors::InvalidArgument( + "Mismatched shapes for input ", input_index, ": ", + (*arg_shapes)[input_index].shape.DebugString(), " vs. ", + info->shape.DebugString()); + } + (*arg_shapes)[input_index] = status.ValueOrDie(); + return Status::OK(); + }; + + for (int64 i = 0; i < params_info.NumReplicas(); ++i) { + for (int64 j = 0; j < params_info.NumPerReplicaArgs(); ++j) { + TF_RETURN_IF_ERROR(check_shape(j)); + } + } + + for (int64 i = 0; i < params_info.NumDistributedArgs(); ++i) { + TF_RETURN_IF_ERROR(check_shape(params_info.NumPerReplicaArgs() + i)); + } + + for (int64 i = 0; + i < params_info.NumPerReplicaArgs() + params_info.NumDistributedArgs(); + ++i) { + if (any_replica_shape_unknown[i]) { + (*arg_shapes)[i].shape = PartialTensorShape(); + (*arg_shapes)[i].handle_shape = PartialTensorShape(); + } + } + + // Determines the shape of the broadcast arguments. + for (int64 i = 0; i < params_info.NumBroadcastArgs(); ++i) { + TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE); + const InferredShape* info; + TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); + (*arg_shapes)[i + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs()] + .shape = info->shape; + ++edge_pos; + } + + // Determines the handle shape and handle type of the resource variable + // arguments. 
+ for (int64 i = 0; i < params_info.NumVariables(); ++i) { + TF_RET_CHECK(node.input_type(edge_pos) == DT_RESOURCE); + const InferredShape* info; + TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); + InferredShape& arg_shape = + (*arg_shapes)[i + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + + params_info.NumBroadcastArgs()]; + arg_shape.shape = TensorShape(); // Variables are always scalars. + arg_shape.handle_shape = info->handle_shape; + arg_shape.handle_type = info->handle_type; + TF_RET_CHECK(arg_shape.handle_type != DT_INVALID); + ++edge_pos; + } + + // Determines the shape of the guaranteed constants. + // TODO(vinuraja): Can be removed because they are not required for any + // calculations. Leaving them here for symmetry with other structures like + // arg_types, arg_sharding, etc. + for (int64 i = 0; i < params_info.NumGuaranteedConstants(); ++i) { + TF_RET_CHECK(node.input_type(edge_pos) != DT_RESOURCE); + const InferredShape* info; + TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); + (*arg_shapes)[i + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + + params_info.NumBroadcastArgs() + params_info.NumVariables()] + .shape = info->shape; + ++edge_pos; + } + + // Extract the return value shapes. + auto it = shape_info.find(node.name()); + retval_shapes->clear(); + if (it != shape_info.end()) { + TF_RET_CHECK(it->second.size() >= node.num_outputs()); + retval_shapes->resize(node.num_outputs()); + for (int i = 0; i < node.num_outputs(); ++i) { + (*retval_shapes)[i].shape = it->second[i].shape; + } + } else if (node.num_outputs() > 0) { + return errors::InvalidArgument( + "Replicated TPU computation is missing InferredShape: ", + FormatNodeForError(node)); + } + return Status::OK(); +} + +// Verifies that all nodes have legal sharding. +static Status ValidateCoreNumbers(const Graph& graph, + int num_cores_per_replica) { + for (Node* n : graph.nodes()) { + TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> sharding, + ParseShardingFromDevice(*n, num_cores_per_replica)); + } + return Status::OK(); +} + +static Status InferXlaShardingFromNeighbors( + const Node& n, int num_cores_per_replica, FunctionLibraryRuntime* flr, + CachedFunctionHandles* cached_function_handles, + absl::optional<xla::OpSharding>* output_sharding, bool* is_fast_mem) { + int64 core = -1; + absl::optional<xla::OpSharding> result; + // We assume the variable has been allocated on fast memory if any consuming + // op has TPU_FAST_MEM_ATTR attribute. This is a protocol between runtime and + // compiler. + *is_fast_mem = false; + for (const Edge* edge : n.out_edges()) { + if (edge->IsControlEdge()) continue; + + TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors( + num_cores_per_replica, n.name(), *edge->dst(), &core, is_fast_mem, + &result)); + + if (!flr) continue; + + // The nodes deciding this arg's device assignment might be in + // FunctionDef. Instantiate FunctionDefs associated with this node + // and check nodes using this arg. 
+    std::function<Status(const Edge* call_edge)> parse_sharding_from_function =
+        [&](const Edge* call_edge) {
+          auto associated_functions = GetAssociatedFunctions(
+              *call_edge->dst(), flr->GetFunctionLibraryDefinition());
+          for (auto& associated_function : associated_functions) {
+            FunctionLibraryRuntime::Handle handle;
+            TF_RETURN_IF_ERROR(cached_function_handles->GetOrInstantiate(
+                associated_function.func_name(),
+                AttrSlice(&associated_function.attrs()), &handle));
+            const FunctionBody* body = flr->GetFunctionBody(handle);
+            Graph* g = body->graph;
+
+            for (Node* body_node : g->nodes()) {
+              if (!body_node->IsArg()) continue;
+
+              int index;
+              TF_RETURN_IF_ERROR(
+                  GetNodeAttr(body_node->attrs(), "index", &index));
+              if (index != call_edge->dst_input()) continue;
+
+              for (const Edge* out_edge : body_node->out_edges()) {
+                if (out_edge->IsControlEdge()) continue;
+
+                TF_RETURN_IF_ERROR(ParseAndValidateShardingFromNeighbors(
+                    num_cores_per_replica, n.name(), *out_edge->dst(), &core,
+                    is_fast_mem, &result));
+
+                TF_RETURN_IF_ERROR(parse_sharding_from_function(out_edge));
+              }
+            }
+          }
+          return Status::OK();
+        };
+    TF_RETURN_IF_ERROR(parse_sharding_from_function(edge));
+  }
+  *output_sharding = result;
+  return Status::OK();
+}
+
+bool UseSpmdForXlaPartitioning(const Node* replicate_node) {
+  bool spmd_attr;
+  if (!replicate_node ||
+      !TryGetNodeAttr(replicate_node->attrs(), "use_spmd_for_xla_partitioning",
+                      &spmd_attr)) {
+    spmd_attr = false;
+  }
+  return spmd_attr;
+}
+
+Status DistributedTPURewritePass::AssignArgsAndRetvalsToCores(
+    int num_cores_per_replica, const ParameterInfo& params_info,
+    const DataTypeVector& arg_types,
+    const std::vector<InferredShape>& arg_shapes,
+    const DataTypeVector& retval_types,
+    const std::vector<InferredShape>& retval_shapes, const Graph& graph,
+    const Node* replicate_node, FunctionLibraryRuntime* flr,
+    std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem,
+    std::vector<xla::OpSharding>* retval_sharding) {
+  // Builds vectors of the argument and return nodes.
+  std::vector<Node*> args(arg_types.size());
+  std::vector<Node*> retvals(retval_types.size());
+  absl::flat_hash_map<int, Node*> partitioned_output_nodes;
+  for (Node* node : graph.op_nodes()) {
+    if (node->IsArg()) {
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0 && index < args.size());
+      args[index] = node;
+    } else if (node->IsRetval()) {
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0 && index < retvals.size());
+      retvals[index] = node;
+    }
+  }
+  for (const Edge* edge : replicate_node->out_edges()) {
+    int num_partitioned_outputs = 0;
+    for (const Edge* out_edge : edge->dst()->out_edges()) {
+      if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
+        partitioned_output_nodes[edge->src_output()] = out_edge->dst();
+        num_partitioned_outputs++;
+      }
+    }
+    if (num_partitioned_outputs > 1) {
+      return errors::InvalidArgument(
+          "More than one TPUPartitionedOutput per replicated output.");
+    }
+  }
+
+  // Verifies there are no missing arguments/return values.
+  for (int i = 0; i < args.size(); ++i) {
+    if (args[i] == nullptr) {
+      return errors::Internal("Missing function argument: ", i);
+    }
+  }
+  for (int i = 0; i < retvals.size(); ++i) {
+    if (retvals[i] == nullptr) {
+      return errors::Internal("Missing function return value: ", i);
+    }
+  }
+
+  // Assigns a core to each _Arg. Chooses the lowest-numbered core that
+  // consumes the argument.
We choose the lowest-numbered core so the + // assignment is deterministic. + TensorDevicePlacer args_device_selector(num_cores_per_replica, arg_types, + arg_shapes); + arg_sharding->resize(args.size()); + arg_fast_mem->resize(args.size()); + CachedFunctionHandles cached_function_handles(flr); + const bool use_spmd = UseSpmdForXlaPartitioning(replicate_node) || + replicate_inputs_outputs_by_default_for_xla_spmd_; + for (int i = 0; i < args.size(); ++i) { + const Node* n = args[i]; + absl::optional<int64> assigned_core; + absl::optional<xla::OpSharding> sharding; + bool is_fast_mem; + TF_RETURN_IF_ERROR(InferXlaShardingFromNeighbors( + *n, num_cores_per_replica, flr, &cached_function_handles, &sharding, + &is_fast_mem)); + + if (params_info.IsPerReplicaArg(i) || params_info.IsDistributedArg(i)) { + Node* input_node; + TF_RETURN_IF_ERROR(replicate_node->input_node(i, &input_node)); + if (input_node->type_string() == kTPUPartitionedInput) { + TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding, + GetShardingFromNodeDef(input_node->def())); + if (!parsed_sharding.has_value()) + return errors::InvalidArgument("Missing _XlaSharding attr from: ", + input_node->DebugString()); + sharding = parsed_sharding; + VLOG(1) << "Arg " << i << " parsed sharding information from " + << input_node->name() << " : " + << parsed_sharding->DebugString(); + } + } + + if (sharding.has_value() && enable_automatic_model_parallelism_) { + return tensorflow::errors::InvalidArgument( + "Specifying manual sharding is not allowed when automatic " + "model parallelism is enabled.", + sharding->DebugString()); + } + + if (!sharding.has_value()) { + if (use_spmd && + (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) || + ((params_info.IsPerReplicaArg(i) || + params_info.IsDistributedArg(i)) && + arg_types[i] != DT_RESOURCE))) { + // Use replication for host variables or non-variable per-replica + // inputs. + sharding = xla::sharding_builder::Replicate(); + } else { + // TODO(dlibenzi): Distributing variables to cores other than 0 makes + // learning/brain/research/babelfish/trainer:trainer_tpu_test fail. + // For now distribute only per replica arguments, unless + // tf_jf_distribute_vars is set, to allow debugging the issue. 
+ if (((params_info.IsPerReplicaArg(i) || + params_info.IsDistributedArg(i)) && + arg_types[i] != DT_RESOURCE) || + (distribute_vars_ && params_info.IsVariableArg(i))) { + assigned_core = args_device_selector.RetrieveAssignment(i); + } else { + assigned_core = 0; + } + sharding = xla::sharding_builder::AssignDevice(*assigned_core); + } + } else if (sharding->type() == xla::OpSharding::MAXIMAL) { + assigned_core = sharding->tile_assignment_devices(0); + } else if (sharding->type() != xla::OpSharding::REPLICATED && + sharding->type() != xla::OpSharding::OTHER) { + return tensorflow::errors::InvalidArgument( + "Unsupported argument sharding: ", sharding->DebugString()); + } + if (assigned_core.has_value()) { + args_device_selector.ReportDeviceAssigned(*assigned_core, i); + VLOG(3) << "Assigning argument " << i << " (" << n->DebugString() + << ") to core " << *assigned_core; + args[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core)); + } else if (sharding->type() == xla::OpSharding::OTHER) { + for (int64 core : sharding->tile_assignment_devices()) { + args_device_selector.ReportDeviceAssigned(core, i); + VLOG(3) << "Assigning argument " << i << " (" << n->DebugString() + << ") with tiled sharding to core " << core; + } + } else { + CHECK_EQ(sharding->type(), xla::OpSharding::REPLICATED); + for (int64 core = 0; core < num_cores_per_replica; ++core) { + args_device_selector.ReportDeviceAssigned(core, i); + } + VLOG(3) << "Assigning argument " << i << " (" << n->DebugString() + << ") to all cores"; + } + (*arg_sharding)[i] = *sharding; + (*arg_fast_mem)[i] = is_fast_mem; + if (is_fast_mem) { + VLOG(3) << "Add " << TPU_FAST_MEM_ATTR << " attribute to " + << args[i]->name(); + } + args[i]->AddAttr(kShardingAttribute, sharding->SerializeAsString()); + } + TF_RETURN_IF_ERROR(cached_function_handles.ReleaseAllHandles()); + + // Assigns each _Retval node to the core that produces its value. 
+ TensorDevicePlacer retvals_device_selector(num_cores_per_replica, + retval_types, retval_shapes); + retval_sharding->resize(retvals.size()); + for (int i = 0; i < retvals.size(); ++i) { + const Edge* edge; + TF_RETURN_IF_ERROR(retvals[i]->input_edge(0, &edge)); + + TF_ASSIGN_OR_RETURN( + absl::optional<xla::OpSharding> sharding, + ParseShardingFromDevice(*edge->src(), num_cores_per_replica)); + + if (partitioned_output_nodes.contains(i)) { + Node* output_node = partitioned_output_nodes[i]; + TF_ASSIGN_OR_RETURN(absl::optional<xla::OpSharding> parsed_sharding, + GetShardingFromNodeDef(output_node->def())); + if (parsed_sharding.has_value()) { + sharding = parsed_sharding; + VLOG(1) << "Retval " << i << " parsed sharding information from " + << output_node->name() << " : " << sharding->DebugString(); + } + } + absl::optional<int64> assigned_core; + if (sharding.has_value()) { + if (enable_automatic_model_parallelism_) { + return tensorflow::errors::InvalidArgument( + "Specifying manual sharding is not allowed when automatic " + "model parallelism is enabled.", + sharding->DebugString()); + } + + if (sharding.value().type() == xla::OpSharding::MAXIMAL) { + assigned_core = sharding.value().tile_assignment_devices(0); + TF_RETURN_IF_ERROR( + ValidateCoreNumber(*assigned_core, num_cores_per_replica)); + } else if (sharding.value().type() != xla::OpSharding::REPLICATED && + sharding.value().type() != xla::OpSharding::OTHER) { + return tensorflow::errors::InvalidArgument( + "Unsupported argument sharding: ", sharding->DebugString()); + } + } else { + if (use_spmd) { + sharding = xla::sharding_builder::Replicate(); + } else { + if (distribute_vars_) { + assigned_core = retvals_device_selector.RetrieveAssignment(i); + } else { + assigned_core = 0; + } + sharding = xla::sharding_builder::AssignDevice(*assigned_core); + } + } + if (assigned_core.has_value()) { + retvals[i]->set_assigned_device_name(CoreDeviceLabel(*assigned_core)); + retvals_device_selector.ReportDeviceAssigned(*assigned_core, i); + VLOG(3) << "Assigning return value " << i << " (" + << retvals[i]->DebugString() << ") to core " << *assigned_core; + } else if (sharding->type() == xla::OpSharding::OTHER) { + for (int64 core : sharding->tile_assignment_devices()) { + retvals_device_selector.ReportDeviceAssigned(core, i); + VLOG(3) << "Assigning return value " << i << " (" + << retvals[i]->DebugString() << ") with tiled sharding to core " + << core; + } + } else { + CHECK_EQ(sharding->type(), xla::OpSharding::REPLICATED); + for (int64 core = 0; core < num_cores_per_replica; ++core) { + retvals_device_selector.ReportDeviceAssigned(core, i); + } + VLOG(3) << "Assigning return value " << i << " (" + << retvals[i]->DebugString() << ") to all cores."; + } + retvals[i]->AddAttr(kShardingAttribute, sharding->SerializeAsString()); + (*retval_sharding)[i] = *sharding; + } + return Status::OK(); +} + +// Builds Shape nodes that compute the shapes of arguments whose shapes are not +// statically known. 
+/* static */ Status DistributedTPURewritePass::BuildDynamicShapeNodes( + const Node& replicate_node, const std::vector<InferredShape>& arg_shapes, + const ParameterInfo& params_info, const std::vector<Node*>& variable_reads, + Graph* graph, std::vector<Node*>* dynamic_shape_nodes) { + dynamic_shape_nodes->clear(); + + std::vector<const Edge*> replicate_input_edges; + TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges)); + + // The compiler determines the shape of each constant by inspecting the value + // of its corresponding host-memory tensor; this happens when a step is run. + // As a result, the shapes of constants are not needed at graph rewrite time. + const int num_args = arg_shapes.size() - params_info.NumGuaranteedConstants(); + TF_RET_CHECK(num_args == params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + + params_info.NumBroadcastArgs() + + params_info.NumVariables()); + + for (int i = 0; i < num_args; ++i) { + const PartialTensorShape* shape = arg_shapes[i].handle_type == DT_INVALID + ? &arg_shapes[i].shape + : &arg_shapes[i].handle_shape; + if (!shape->IsFullyDefined()) { + Node* src; + int src_output; + if (params_info.IsPerReplicaArg(i)) { + TF_RET_CHECK(i < replicate_input_edges.size()); + // All replicas must have the same input shapes. Uses the shape of the + // inputs from the first replica. + src = replicate_input_edges[i]->src(); + src_output = replicate_input_edges[i]->src_output(); + } else if (params_info.IsDistributedArg(i) || + params_info.IsBroadcastArg(i)) { + int64 input_num = + params_info.NumPerReplicaArgs() * params_info.NumReplicas() + i - + params_info.NumPerReplicaArgs(); + TF_RET_CHECK(0 <= input_num && + input_num < replicate_input_edges.size()); + src = replicate_input_edges[input_num]->src(); + src_output = replicate_input_edges[input_num]->src_output(); + } else { + int64 var_num = i - params_info.NumPerReplicaArgs() - + params_info.NumDistributedArgs() - + params_info.NumBroadcastArgs(); + TF_RET_CHECK(0 <= var_num && var_num < variable_reads.size()); + src = variable_reads[var_num]; + src_output = 0; + } + + NodeDef def; + def.set_name(graph->NewName(strings::StrCat(src->name(), "/shape"))); + def.set_op("Shape"); + def.set_device(src->assigned_device_name()); + AddNodeAttr("T", src->output_type(src_output), &def); + AddNodeAttr("out_type", DT_INT64, &def); + MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def); + + Status status; + Node* shape_node = graph->AddNode(def, &status); + if (!status.ok()) return status; + dynamic_shape_nodes->push_back(shape_node); + + shape_node->set_assigned_device_name(src->assigned_device_name()); + graph->AddEdge(src, src_output, shape_node, 0); + } + } + return Status::OK(); +} + +// Builds a TPUCompile node that compiles the bodies of the function call +// `nodes`. 
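An aside on the index arithmetic used by BuildDynamicShapeNodes above and BuildCompileNode below: the flattened argument vector is ordered as per-replica args, then distributed args, broadcast args, variable reads, and finally guaranteed constants. A standalone sketch of the category offsets, with a hypothetical counts struct standing in for ParameterInfo:

struct ArgCounts {
  int per_replica, distributed, broadcast, variables, guaranteed_constants;
};

struct ArgOffsets {
  int distributed_start, broadcast_start, variable_start, constant_start;
};

// Start index of each category within the flattened argument vector.
inline ArgOffsets ComputeArgOffsets(const ArgCounts& c) {
  ArgOffsets o;
  o.distributed_start = c.per_replica;
  o.broadcast_start = o.distributed_start + c.distributed;
  o.variable_start = o.broadcast_start + c.broadcast;
  o.constant_start = o.variable_start + c.variables;
  return o;
}

Under this layout the i-th variable read corresponds to argument index ComputeArgOffsets(counts).variable_start + i, which is the inverse of the var_num computation above.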
+Status DistributedTPURewritePass::BuildCompileNode( + const Node* replicate_node, const NameAttrList& function, + uint64 library_fingerprint, const ParameterInfo& params_info, + const std::vector<InferredShape>& arg_shapes, + const DataTypeVector& arg_types, + const std::vector<Node*>& guaranteed_constant_nodes, + const string& session_handle, + const std::vector<xla::OpSharding>& arg_sharding, + const std::vector<bool>& arg_fast_mem, + const std::vector<xla::OpSharding>& retval_sharding, + int num_cores_per_replica, const string& compile_device, + const xla::DeviceAssignment* xla_device_assignment, + const std::vector<Node*>& dynamic_shape_nodes, Graph* graph, + Node** compile_node, int64 autotuner_thresh) { + VLOG(1) << "BuildCompileNode"; + + tpu::TPUCompileMetadataProto proto; + proto.set_num_replicas(params_info.NumReplicas()); + proto.set_num_cores_per_replica(num_cores_per_replica); + proto.set_function_library_fingerprint(library_fingerprint); + proto.set_enable_automatic_model_parallelism( + enable_cross_replica_sharding_mirrored_variables_); + const bool use_spmd = UseSpmdForXlaPartitioning(replicate_node); + proto.set_use_spmd_for_xla_partitioning(use_spmd); + + // Get and fill padding map. + if (replicate_node != nullptr) { + TF_RETURN_IF_ERROR( + FillPaddingMap(*replicate_node, proto.mutable_padding_maps())); + xla::DebugOptions::StepMarkerLocation location; + TF_RETURN_IF_ERROR(GetStepMarkerLocation(*replicate_node, &location)); + proto.set_step_marker_location(location); + } + + if (xla_device_assignment != nullptr) { + TF_RETURN_IF_ERROR( + xla_device_assignment->Serialize(proto.mutable_device_assignment())); + } + + const int num_args = arg_types.size(); + const int num_guaranteed_constants = guaranteed_constant_nodes.size(); + const int guaranteed_const_start_index = num_args - num_guaranteed_constants; + TF_RET_CHECK(num_args == arg_shapes.size()); + TF_RET_CHECK(num_args == arg_sharding.size()) + << num_args << " != " << arg_sharding.size(); + + for (int i = 0; i < num_args; ++i) { + tpu::TPUCompileMetadataProto::Arg* arg = proto.add_args(); + DataType type = arg_types[i]; + const InferredShape& arg_shape = arg_shapes[i]; + if (type == DT_RESOURCE) { + TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i; + arg->set_dtype(arg_shape.handle_type); + arg_shape.handle_shape.AsProto(arg->mutable_shape()); + arg->set_kind(tpu::TPUCompileMetadataProto::Arg::VARIABLE); + arg->set_fast_mem(arg_fast_mem[i]); + } else { + arg->set_dtype(type); + arg_shape.shape.AsProto(arg->mutable_shape()); + if (i >= guaranteed_const_start_index) { + const DataType edge_type = + guaranteed_constant_nodes[i - guaranteed_const_start_index] + ->output_type(0); + TF_RET_CHECK(type == edge_type) + << "Arg type: " << type << " but edge type: " << edge_type; + arg->set_kind(tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT); + } else { + arg->set_kind(tpu::TPUCompileMetadataProto::Arg::PARAMETER); + } + } + // As long as the argument is not a per-replica one, it should have the same + // value for all replicas. For clarity, we keep the (redundant) checks for + // variable, broadcast and constant types, to prevent bugs in case new types + // with different semantics are introduced in the future. 
+ arg->set_is_same_data_across_replicas( + !params_info.IsPerReplicaArg(i) && !params_info.IsDistributedArg(i) && + (params_info.IsVariableArg(i) || params_info.IsBroadcastArg(i) || + params_info.IsConstantArg(i))); + if (params_info.mirrored_variable_indices().count(i) > 0) { + CHECK_EQ(type, DT_RESOURCE); + arg->set_is_same_data_across_replicas(true); + // 64-bit type is not shardable by XLA:TPU yet. + bool sharding_enabled = (arg_shape.handle_type != DT_COMPLEX64 && + arg_shape.handle_type != DT_INT64 && + arg_shape.handle_type != DT_UINT64 && + arg_shape.handle_type != DT_DOUBLE); + arg->set_enable_xla_sharding( + sharding_enabled ? tpu::TPUCompileMetadataProto::Arg::TENTATIVE + : tpu::TPUCompileMetadataProto::Arg::DISALLOWED); + } + *arg->mutable_sharding() = arg_sharding[i]; + } + + const int num_retvals = retval_sharding.size(); + for (int i = 0; i < num_retvals; ++i) { + *proto.add_retvals()->mutable_sharding() = retval_sharding[i]; + } + proto.set_session_handle(session_handle); + + DataTypeVector constant_arg_types; + constant_arg_types.reserve(num_guaranteed_constants); + for (int i = 0; i < num_guaranteed_constants; ++i) { + constant_arg_types.push_back(arg_types[guaranteed_const_start_index + i]); + } + proto.set_xla_fusion_autotuner_thresh(autotuner_thresh); + + string metadata; + proto.SerializeToString(&metadata); + + NodeDef def; + def.set_name(UniqueNodeName("TPUReplicate/_compile", graph)); + def.set_op("TPUCompile"); + def.set_device(compile_device); + if (replicate_node) { + MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def); + } + + AddNodeAttr("function", function, &def); + AddNodeAttr("num_computations", num_cores_per_replica, &def); + AddNodeAttr("NumDynamicShapes", static_cast<int>(dynamic_shape_nodes.size()), + &def); + AddNodeAttr("metadata", metadata, &def); + AddNodeAttr("Tguaranteed_constants", constant_arg_types, &def); + + Status status; + *compile_node = graph->AddNode(def, &status); + TF_RETURN_IF_ERROR(status); + + (*compile_node)->set_assigned_device_name(compile_device); + + for (int i = 0; i < dynamic_shape_nodes.size(); ++i) { + graph->AddEdge(dynamic_shape_nodes[i], 0, *compile_node, i); + } + + for (int i = 0; i < num_guaranteed_constants; ++i) { + graph->AddEdge(guaranteed_constant_nodes[i], 0, *compile_node, + dynamic_shape_nodes.size() + i); + } + VLOG(1) << "BuildCompileNode(): " << status; + return status; +} + +Status DistributedTPURewritePass::FindGuaranteedConstantInputs( + const Node& node, const NameRangeMap& input_range_map, + std::vector<Node*>* guaranteed_constants) { + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(node.input_edges(&input_edges)); + std::pair<int, int> variables_limits = + input_range_map.at("guaranteed_constants"); + for (int i = variables_limits.first; i < variables_limits.second; ++i) { + guaranteed_constants->push_back(input_edges[i]->src()); + } + return Status::OK(); +} + +Status DistributedTPURewritePass::FindVariableInputs( + const Node& node, const NameRangeMap& input_range_map, + std::vector<VariableInput>* variables) { + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(node.input_edges(&input_edges)); + std::pair<int, int> variables_limits = input_range_map.at("variables"); + for (int i = variables_limits.first; i < variables_limits.second; ++i) { + Node* node = input_edges[i]->src(); + + // Find the type of the VarHandleOp that feeds this node, looking through + // any wrapping Enter or Switch nodes. 
+ while (node->IsEnter() || node->IsSwitch()) { + TF_RETURN_IF_ERROR(node->input_node(0, &node)); + } + // Fix the variable device assignment if it is requested with a full name. + if (!node->has_assigned_device_name() && + !node->requested_device().empty()) { + DeviceNameUtils::ParsedName var_device; + TF_RET_CHECK(DeviceNameUtils::ParseFullName(node->requested_device(), + &var_device)); + if (var_device.has_job && var_device.has_replica && var_device.has_task && + var_device.has_type && var_device.has_id) { + node->set_assigned_device_name(node->requested_device()); + if (node != input_edges[i]->src() && + !input_edges[i]->src()->has_assigned_device_name()) { + input_edges[i]->src()->set_assigned_device_name( + node->requested_device()); + } + } + } + if (node->type_string() == "VarHandleOp") { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "dtype", &dtype)); + variables->push_back(VariableInput{input_edges[i]->src(), + input_edges[i]->src_output(), dtype}); + } else if (node->type_string() == "_Arg") { + std::vector<DataType> dtypes; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "_handle_dtypes", &dtypes)); + if (dtypes.empty()) { + return errors::Internal( + "_Arg node with resource output must have non-empty _handle_dtypes " + "attribute: ", + node->DebugString()); + } + variables->push_back(VariableInput{ + input_edges[i]->src(), input_edges[i]->src_output(), dtypes[0]}); + } else { + return errors::Internal( + "Cannot handle variable input with node type other than VarHandleOp " + "and _Arg: ", + node->DebugString()); + } + } + return Status::OK(); +} + +// Builds a NoOp node, used for building control dependencies. +static Status BuildNoopNode(const Node& source, StringPiece name, + const string& device, Graph* graph, Node** node) { + NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source)); + if (!device.empty()) { + builder.Device(device); + } + NodeDef def; + TF_RETURN_IF_ERROR(builder.Finalize(&def)); + + Status status; + *node = graph->AddNode(def, &status); + if (!device.empty()) { + (*node)->set_assigned_device_name(device); + } + return status; +} + +Status DistributedTPURewritePass::ConnectHostComputeNodes( + Node* compile_node, Node* key_placeholder_node, Graph* graph) { + // First find all the downstream nodes of the key placeholder node, since we + // want to delete the connecting edges from key_placeholder_node which would + // invalidate the out_nodes iterator. 
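The comment above describes a recurring graph-rewrite pattern: snapshot a node's fan-out before rewiring it, because removing edges while iterating out_nodes()/out_edges() would invalidate the iterator. A minimal standalone sketch of the same idea on a toy adjacency list (not TensorFlow's Graph API):

#include <vector>

struct ToyNode {
  std::vector<ToyNode*> out;  // fan-out of this node
};

// Moves every consumer of old_src over to new_src. The fan-out is copied
// first; mutating old_src->out while ranging over it would be unsafe.
inline void RewireFanOut(ToyNode* old_src, ToyNode* new_src) {
  std::vector<ToyNode*> consumers = old_src->out;
  old_src->out.clear();
  for (ToyNode* consumer : consumers) {
    new_src->out.push_back(consumer);
  }
}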
+ std::vector<Node*> host_transfer_nodes; + for (Node* node : key_placeholder_node->out_nodes()) { + host_transfer_nodes.push_back(node); + } + for (Node* node : host_transfer_nodes) { + int input_index = -1; + for (int i = 0; i < node->num_inputs(); i++) { + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(i, &e)); + if (e->src() == key_placeholder_node) { + if (input_index != -1) { + return errors::Internal( + "Node ", node->name(), + " has multiple input edges from key placeholder node"); + } + input_index = e->dst_input(); + } + } + if (input_index == -1) { + return errors::Internal("Node ", node->name(), + " has no input edge from key placeholder node"); + } + const Edge* key_edge; + TF_RETURN_IF_ERROR(node->input_edge(input_index, &key_edge)); + graph->RemoveEdge(key_edge); + graph->AddEdge(compile_node, 1, node, input_index); + } + graph->RemoveNode(key_placeholder_node); + return Status::OK(); +} + +Status DistributedTPURewritePass::BuildVariableReads( + absl::Span<const VariableInput> variables, Node* control_predecessor, + Graph* graph, std::vector<Node*>* variable_reads) { + variable_reads->resize(variables.size()); + for (int i = 0; i < variables.size(); ++i) { + string name = + graph->NewName(strings::StrCat(variables[i].node->name(), "/read")); + NodeDefBuilder builder(name, "ReadVariableOp", + NodeDebugInfo(*variables[i].node)); + + builder.Attr("dtype", variables[i].dtype); + builder.Device(variables[i].node->assigned_device_name()); + builder.Input(variables[i].node->name(), 0, DT_RESOURCE); + NodeDef def; + TF_RETURN_IF_ERROR(builder.Finalize(&def)); + + Status status; + Node* read_node; + (*variable_reads)[i] = read_node = graph->AddNode(def, &status); + if (!status.ok()) return status; + + read_node->set_requested_device(variables[i].node->requested_device()); + read_node->set_assigned_device_name( + variables[i].node->assigned_device_name()); + graph->AddEdge(variables[i].node, variables[i].index, read_node, 0); + + graph->AddControlEdge(control_predecessor, read_node); + } + return Status::OK(); +} + +bool DistributedTPURewritePass::ContainsResourceWriteOp( + const Graph& graph, const FunctionLibraryDefinition& fld) { + for (const Node* n : graph.nodes()) { + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string()); + if (op_info && op_info->kind() != XlaResourceOpKind::kRead) { + VLOG(2) << "Found write resource op inside computation"; + return true; + } + } + for (const string& func_name : fld.ListFunctionNames()) { + const FunctionDef* func_def = fld.Find(func_name); + for (const NodeDef& n : func_def->node_def()) { + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.op()); + if (op_info && op_info->kind() != XlaResourceOpKind::kRead) { + VLOG(2) << "Found write resource op inside " << func_name; + return true; + } + } + } + return false; +} + +Status DistributedTPURewritePass::BuildVariableWrites( + absl::Span<const VariableInput> variables, Node* control_successor, + absl::Span<const VariableWrite> variable_writes, Graph* graph) { + CHECK_EQ(variables.size(), variable_writes.size()); + for (int i = 0; i < variables.size(); ++i) { + const VariableWrite& write = variable_writes[i]; + NodeDebugInfo debug_info(*variables[i].node); + + auto name = [&](string suffix) { + return graph->NewName( + strings::StrCat(variables[i].node->name(), "/", suffix)); + }; + + Node* write_node; + TF_RETURN_IF_ERROR( + IncompleteNodeDefBuilder(name("assign"), "AssignVariableOp", debug_info) + .AddAttr("dtype", variables[i].dtype) + 
.Device(variables[i].node->assigned_device_name()) + .Build(graph, &write_node)); + + // Colocate the control flow with the variable. + CondBuilder cb(variables[i].node->name(), + variables[i].node->assigned_device_name(), debug_info, + graph); + + // Inputs to conditional. + Node* switch_val; + TF_RETURN_IF_ERROR( + cb.AddInput("switch_val", variables[i].dtype, + /*device=*/write.value->assigned_device_name(), debug_info, + &switch_val)); + Node* switch_var; + TF_RETURN_IF_ERROR( + cb.AddInput("switch_var", DT_RESOURCE, + /*device=*/variables[i].node->assigned_device_name(), + debug_info, &switch_var)); + // Conditionally write the value back. + graph->AddEdge(variables[i].node, variables[i].index, switch_var, 0); + graph->AddEdge(switch_var, CondBuilder::kThenBranch, write_node, 0); + graph->AddEdge(switch_val, CondBuilder::kThenBranch, write_node, 1); + // Add control edge from the write to value that will be merged. There is no + // output from the write so this control edge ensures the write completes. + graph->AddControlEdge(write_node, cb.switch_t()); + + graph->AddControlEdge(cb.control_successor(), control_successor); + + graph->AddEdge(write.predicate, write.predicate_output, cb.pred(), 0); + graph->AddEdge(write.value, write.value_output, switch_val, 0); + } + return Status::OK(); +} + +namespace { + +// Helper that creates an IdentityN node containing all of the variables +// values on CPU device 'device', except for those that will be split across +// cores. (For split variables, this may cause additional cross-host data +// transfers if more than 1 devices share the same variable partition on a +// remote host.) +// +// A previous iteration of this code built one Identity node per TPU core per +// variable, but this can rapidly become hundreds of thousands of nodes. This +// formulation creates a single IdentityN node containing all of the variables +// on each host. This may cause some unnecessary variable copies if only a +// subset of hosts consume a given variable, but has the virtue of being +// simple, and most models use pure replication where all cores want all the +// variables. +// +// Returns the node and its output index to be consumed by TPUExecute for the +// requested variable index. +xla::StatusOr<NodeOut> CreateOrGetPerHostVariableCopy( + const string& host_cpu_device, int64 var_index, + const std::vector<Node*>& variable_reads, + const DistributedTPURewritePass::ParameterInfo& params_info, + const std::vector<xla::OpSharding>& arg_shardings, + const Node& replicate_node, + absl::flat_hash_map<string, std::vector<NodeOut>>* per_host_var_copies, + Graph* graph) { + auto it = per_host_var_copies->find(host_cpu_device); + if (it != per_host_var_copies->end()) { + return it->second[var_index]; + } + + DataTypeVector dtypes; + // Per-variable data source for TPUExecute. + std::vector<NodeOut> index_mapping; + index_mapping.reserve(variable_reads.size()); + dtypes.reserve(variable_reads.size()); + for (int64 i = 0; i < variable_reads.size(); ++i) { + Node* read = variable_reads[i]; + int64 orig_arg_num = + i + params_info.NumPerReplicaArgs() + params_info.NumBroadcastArgs(); + if (arg_shardings[orig_arg_num].type() != xla::OpSharding::OTHER) { + // We haven't built the IdentityN node yet, so temporarily use nullptr. + index_mapping.push_back( + NodeOut{nullptr, static_cast<int>(dtypes.size())}); + dtypes.push_back(read->output_type(0)); + } else { + // Do not copy the full tensor of partitioned variables. 
+ index_mapping.push_back(NodeOut{read, 0}); + } + } + NodeDef ndef; + ndef.set_name( + graph->NewName(absl::StrCat(replicate_node.name(), "/_variable_copy"))); + ndef.set_op("IdentityN"); + ndef.set_device(host_cpu_device); + AddNodeAttr("T", dtypes, &ndef); + Status s; + Node* id_node = graph->AddNode(ndef, &s); + TF_RETURN_IF_ERROR(s); + id_node->set_assigned_device_name(host_cpu_device); + + for (int64 i = 0; i < variable_reads.size(); ++i) { + if (index_mapping[i].node == nullptr) { + // Fill index_mapping with the actual IdentityN node. + index_mapping[i].node = id_node; + // Add the edge to id_node. + graph->AddEdge(variable_reads[i], 0, id_node, index_mapping[i].index); + } + } + + auto result = index_mapping[var_index]; + (*per_host_var_copies)[host_cpu_device] = std::move(index_mapping); + return result; +} + +} // namespace + +Status DistributedTPURewritePass::BuildExecuteNodes( + const ParameterInfo& params_info, int num_tasks, int num_cores_per_replica, + const Node& replicate_node, const DataTypeVector& arg_types, + const std::vector<InferredShape>& arg_shapes, + const DataTypeVector& retval_types, + const std::vector<xla::OpSharding>& arg_shardings, + const std::vector<xla::OpSharding>& retval_shardings, + const std::vector<std::vector<string>>& tpu_device_names, + Node* compile_node, const std::vector<Node*>& variable_reads, + Node* control_predecessor, Node* control_successor, + std::vector<VariableWrite>* variable_writes, Graph* graph) { + VLOG(1) << "BuildExecuteNodes " << replicate_node.DebugString(); + TF_RET_CHECK(params_info.NumReplicas() == tpu_device_names.size()); + + const int num_variables = variable_reads.size(); + const int num_retvals_per_replica = retval_types.size(); + + variable_writes->resize(num_variables); + + std::vector<const Edge*> replicate_input_edges; + TF_RETURN_IF_ERROR(replicate_node.input_edges(&replicate_input_edges)); + + // Map from replicate input index to the fan_in node; + absl::flat_hash_map<int, std::vector<Node*>> replicate_input_fan_in_nodes; + absl::flat_hash_map<int, std::vector<Node*>> replicate_output_fan_out_nodes; + absl::flat_hash_map<int, std::vector<int>> + replicate_output_fan_out_dst_inputs; + std::vector<Node*> to_be_removed_nodes; + + for (const Edge* e : replicate_input_edges) { + if (e->src()->type_string() == kTPUPartitionedInput) { + int num_users = 0; + for (const auto& ue : e->src()->out_edges()) { + if (!ue->IsControlEdge()) ++num_users; + } + if (num_users != 1) { + return tensorflow::errors::InvalidArgument( + e->src()->name(), " must only have one user. Found ", num_users); + } + to_be_removed_nodes.push_back(e->src()); + std::vector<Node*>& nodes = replicate_input_fan_in_nodes[e->dst_input()]; + nodes.resize(num_cores_per_replica, nullptr); + VLOG(2) << "allocate " << num_cores_per_replica + << " for replicate_input_fan_in_nodes[" << e->dst_input() << "]"; + std::vector<const Edge*> fan_in_edges; + TF_RETURN_IF_ERROR(e->src()->input_edges(&fan_in_edges)); + TF_RET_CHECK(fan_in_edges.size() == num_cores_per_replica); + + for (const Edge* fe : fan_in_edges) { + nodes[fe->dst_input()] = fe->src(); + VLOG(2) << "replicate_input_fan_in_nodes[" << e->dst_input() << "][" + << fe->dst_input() << "] = " << fe->src()->name(); + } + } + } + + // Replicate output edges are sorted by replica id and then by outputs for + // each replica. 
For example, if the TPU computation has outputs (output_1,
+  // output_2, and output_3) and the number of replicas is 2, then the
+  // replicate_output_edges order would be:
+  // output_1_replica_1, output_2_replica_1, output_3_replica_1,
+  // output_1_replica_2, output_2_replica_2, output_3_replica_2.
+  std::vector<const Edge*> replicate_output_edges(replicate_node.num_outputs(),
+                                                  nullptr);
+  for (const Edge* edge : replicate_node.out_edges()) {
+    if (edge->IsControlEdge()) continue;
+
+    int num_partitioned_outputs = 0;
+
+    for (const Edge* out_edge : edge->dst()->out_edges()) {
+      if (out_edge->dst()->type_string() == kTPUPartitionedOutput) {
+        num_partitioned_outputs++;
+        // Paths between replicate_node and replicate_output_fan_out_nodes:
+        // ReplicateNode->TpuOutIdentity->kTPUPartitionedOutput->fan-out-nodes
+        TF_RET_CHECK(edge->dst()->out_edges().size() == 1);
+        to_be_removed_nodes.push_back(edge->dst());
+        to_be_removed_nodes.push_back(out_edge->dst());
+        // Get the right replicated id from the replicate_output_edge.
+        std::vector<Node*>& nodes =
+            replicate_output_fan_out_nodes[edge->src_output()];
+        std::vector<int>& dst_inputs =
+            replicate_output_fan_out_dst_inputs[edge->src_output()];
+        nodes.resize(num_cores_per_replica, nullptr);
+        dst_inputs.resize(num_cores_per_replica, 0);
+        TF_RET_CHECK(out_edge->dst()->out_edges().size() ==
+                     num_cores_per_replica);
+
+        for (const Edge* fe : out_edge->dst()->out_edges()) {
+          nodes[fe->src_output()] = fe->dst();
+          dst_inputs[fe->src_output()] = fe->dst_input();
+          VLOG(2) << "replicate_output_fan_out_nodes[" << out_edge->src_output()
+                  << "][" << fe->src_output()
+                  << "] = " << fe->dst()->DebugString() << " with dst_input "
+                  << fe->dst_input();
+        }
+      }
+    }
+    replicate_output_edges[edge->src_output()] = edge;
+    if (num_partitioned_outputs > 1) {
+      return errors::InvalidArgument(
+          "More than one TPUPartitionedOutput per replicated output.");
+    }
+  }
+
+  const int num_execute_args =
+      arg_shardings.size() - params_info.NumGuaranteedConstants();
+  // Inverts the arg_shardings and retval_shardings mappings to
+  // form core -> {argument number} maps.
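A standalone sketch, not part of this change, of the inversion described above, with a simplified stand-in for xla::OpSharding: MAXIMAL pins an argument to one core, OTHER (tiled) touches the listed cores, and REPLICATED touches every core.

#include <vector>

struct ToySharding {
  enum Kind { kMaximal, kTiled, kReplicated } kind;
  std::vector<int> cores;  // used by kMaximal (one entry) and kTiled
};

// Builds core -> {argument number} lists from per-argument shardings.
inline std::vector<std::vector<int>> InvertToPerCoreArgs(
    const std::vector<ToySharding>& arg_shardings, int num_cores_per_replica) {
  std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica);
  for (int arg = 0; arg < static_cast<int>(arg_shardings.size()); ++arg) {
    const ToySharding& s = arg_shardings[arg];
    switch (s.kind) {
      case ToySharding::kMaximal:
        core_arg_nums[s.cores[0]].push_back(arg);
        break;
      case ToySharding::kTiled:
        for (int core : s.cores) core_arg_nums[core].push_back(arg);
        break;
      case ToySharding::kReplicated:
        for (int core = 0; core < num_cores_per_replica; ++core) {
          core_arg_nums[core].push_back(arg);
        }
        break;
    }
  }
  return core_arg_nums;
}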
+ std::vector<std::vector<int>> core_arg_nums(num_cores_per_replica); + for (int i = 0; i < num_execute_args; ++i) { + const auto& sharding = arg_shardings[i]; + if (sharding.type() == xla::OpSharding::MAXIMAL) { + int core = sharding.tile_assignment_devices(0); + TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica)); + core_arg_nums[core].push_back(i); + } else if (sharding.type() == xla::OpSharding::OTHER) { + for (int64 core : sharding.tile_assignment_devices()) { + core_arg_nums[core].push_back(i); + } + } else if (sharding.type() == xla::OpSharding::REPLICATED) { + for (int core = 0; core < num_cores_per_replica; ++core) { + core_arg_nums[core].push_back(i); + } + } else { + return tensorflow::errors::InvalidArgument( + "Unsupported argument sharding: ", sharding.DebugString()); + } + } + std::vector<std::vector<int>> core_retval_nums(num_cores_per_replica); + for (int i = 0; i < retval_shardings.size(); ++i) { + const auto& sharding = retval_shardings[i]; + if (sharding.type() == xla::OpSharding::MAXIMAL) { + int core = sharding.tile_assignment_devices(0); + TF_RETURN_IF_ERROR(ValidateCoreNumber(core, num_cores_per_replica)); + core_retval_nums[core].push_back(i); + } else if (sharding.type() == xla::OpSharding::REPLICATED) { + for (int core = 0; core < num_cores_per_replica; ++core) { + core_retval_nums[core].push_back(i); + } + } else if (sharding.type() == xla::OpSharding::OTHER) { + for (int64 core : sharding.tile_assignment_devices()) { + core_retval_nums[core].push_back(i); + } + } else { + return tensorflow::errors::InvalidArgument( + "Unsupported argument sharding: ", sharding.DebugString()); + } + } + + // Maps host device name to a list of per-variable pairs (variable_copy_node, + // output_index_of_copy_node). + absl::flat_hash_map<string, std::vector<NodeOut>> per_host_var_copies; + + // Mapping from original resource arg number to a second level map. Second + // level map is from core id to output index of updated variable value. + absl::flat_hash_map<int, absl::flat_hash_map<int, int>> + orig_arg_num_to_output_index_mapping; + // Mapping from retval index to a second level map. Second level map is from + // core id to output index of sharded output value. + std::unordered_map<int, std::unordered_map<int, int>> + retval_index_to_output_index_mapping; + + // Represents mapping of argument index of sharded input to each + // TPUExecute node to its corresponding Split node and its output index + // from which sharded input will be fed into TPUExecute node. + std::map<ShardedInputIndex, ShardedInputInfo> input_index_to_sharded_inputs; + + // Builds one TPUExecute node per core per replica. + std::vector<std::vector<Node*>> execute_nodes(params_info.NumReplicas()); + for (int core = 0; core < num_cores_per_replica; ++core) { + DataTypeVector core_retval_types; + for (int output : core_retval_nums[core]) { + core_retval_types.push_back(retval_types[output]); + } + DataTypeVector core_arg_types; + std::vector<int> core_variable_writes; + for (int input : core_arg_nums[core]) { + // Resource variables can be passed either by reference (as a DT_RESOURCE) + // tensor or by value (as the variable's current value). Per-replica or + // distributed resource arguments are always passed by reference and + // broadcast variables are always passed by value. 
+ if (arg_types[input] == DT_RESOURCE && + !params_info.IsPerReplicaArg(input) && + !params_info.IsDistributedArg(input)) { + DataType handle_type = arg_shapes[input].handle_type; + TF_RET_CHECK(handle_type != DT_INVALID) << DataTypeString(handle_type); + core_arg_types.push_back(handle_type); + int base = input - params_info.NumPerReplicaArgs() - + params_info.NumDistributedArgs() - + params_info.NumBroadcastArgs(); + // Variables passed by value will have a corresponding additional output + // containing an updated value for the variable. + core_variable_writes.push_back(base); + core_retval_types.push_back(handle_type); + } else { + core_arg_types.push_back(arg_types[input]); + } + } + + NodeDef def; + def.set_op("TPUExecute"); + MergeDebugInfo(NodeDebugInfo(replicate_node.def()), &def); + AddNodeAttr("Targs", core_arg_types, &def); + AddNodeAttr("Tresults", core_retval_types, &def); + + for (int64 replica = 0; replica < params_info.NumReplicas(); ++replica) { + def.set_name(strings::StrCat(replicate_node.name(), "/_execute_", replica, + "_", core)); + + Status status; + Node* node = graph->AddNode(def, &status); + if (!status.ok()) return status; + execute_nodes[replica].push_back(node); + + node->set_assigned_device_name(tpu_device_names[replica][core]); + + // Add control edges to ensure that execution happens after + // `control_predecessor`, happens before `control_successor`, and is + // triggered by evaluating any operator that depends on the original + // TPUReplicate operator. See the comment at the top of the header file + // for more details. + graph->AddControlEdge(control_predecessor, node); + graph->AddControlEdge(node, control_successor); + + // Add data input edges. + for (int64 i = 0; i < core_arg_nums[core].size(); ++i) { + int64 orig_arg_num = core_arg_nums[core][i]; + VLOG(2) << " replica " << replica << " core " << core << " i " << i + << " orig_arg_num " << orig_arg_num; + if (params_info.IsPerReplicaArg(orig_arg_num) || + params_info.IsDistributedArg(orig_arg_num)) { + // Per-replica input and distributed input + int64 input_num = params_info.IsPerReplicaArg(orig_arg_num) + ? 
replica * params_info.NumPerReplicaArgs() + + core_arg_nums[core][i] + : params_info.NumReplicas() * + params_info.NumPerReplicaArgs() + + core_arg_nums[core][i] - + params_info.NumPerReplicaArgs(); + + const Edge* edge = replicate_input_edges[input_num]; + VLOG(2) << "replicate_input_edges[" << input_num << "]"; + DataType dtype = edge->src()->output_type(edge->src_output()); + if (dtype == DT_RESOURCE) { + DataType handle_dtype = arg_shapes[orig_arg_num].handle_type; + if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), + handle_dtype) == kTpuAllTypes.end()) { + return errors::InvalidArgument( + "Unsupported resource variable data type for TPU: ", + DataTypeString(handle_dtype), ", caused by output ", + edge->src()->name(), ":", edge->src_output()); + } + } else { + if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) == + kTpuAllTypes.end()) { + return errors::InvalidArgument( + "Unsupported data type for TPU: ", DataTypeString(dtype), + ", caused by output ", edge->src()->name(), ":", + edge->src_output()); + } + } + if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { + // Don't automatically add a split node when input node is + // kTPUPartitionedInput + if (edge->src()->type_string() == kTPUPartitionedInput) { + VLOG(2) << "Connect " + << replicate_input_fan_in_nodes[input_num][core]->name() + << " to " << node->name() << " at " << i; + graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0, + node, i); + } else { + if (dtype == DT_RESOURCE) { + return errors::InvalidArgument( + "Tiled sharding for per-replica DT_RESOURCE input must", + "be TPUPartitionedInput. Here got ", + edge->src()->type_string()); + } + const xla::OpSharding& sharding = arg_shardings[orig_arg_num]; + + // Create or get the Split node. + TF_ASSIGN_OR_RETURN( + ShardedInputInfo sharded_input_info, + CreateOrGetSplitNodesForInputSharding( + sharding, orig_arg_num, dtype, replica, + edge->src_output(), edge->src(), control_predecessor, + graph, &input_index_to_sharded_inputs)); + + // Calculate which output we should receive from the Split node. + absl::optional<int> output_index = + GetCoreIndexInSharding(sharding, core); + TF_RET_CHECK(output_index); + + NodeOut split_node_and_index = + sharded_input_info.sharded_inputs.at(output_index.value()); + // Connect with Split node output. + graph->AddEdge(split_node_and_index.node, + split_node_and_index.index, node, i); + } + } else if (edge->src()->type_string() == kTPUPartitionedInput && + arg_shardings[orig_arg_num].type() == + xla::OpSharding::REPLICATED) { + graph->AddEdge(replicate_input_fan_in_nodes[input_num][core], 0, + node, i); + } else { + graph->AddEdge(edge->src(), edge->src_output(), node, i); + } + } else if (params_info.IsBroadcastArg(orig_arg_num)) { + // Broadcast input. + int64 input_num = params_info.FirstBroadcastArgFromHost() + + core_arg_nums[core][i] - + params_info.NumPerReplicaArgs() - + params_info.NumDistributedArgs(); + const Edge* edge = replicate_input_edges[input_num]; + DataType dtype = edge->src()->output_type(edge->src_output()); + if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) == + kTpuAllTypes.end()) { + return errors::InvalidArgument( + "Unsupported data type for TPU: ", DataTypeString(dtype), + ", caused by output ", edge->src()->name(), ":", + edge->src_output()); + } + graph->AddEdge(edge->src(), edge->src_output(), node, i); + } else { + // Variable input. 
+ int64 variable_num = orig_arg_num - params_info.NumPerReplicaArgs() - + params_info.NumDistributedArgs() - + params_info.NumBroadcastArgs(); + TF_RET_CHECK(variable_num < num_variables); + + Node* variable_read = variable_reads[variable_num]; + DataType dtype = variable_read->output_type(0); + if (std::find(kTpuAllTypes.begin(), kTpuAllTypes.end(), dtype) == + kTpuAllTypes.end()) { + return errors::InvalidArgument( + "Unsupported resource variable data type for TPU: ", + DataTypeString(dtype), ", caused by ReadVariableOp ", + variable_read->DebugString()); + } + DeviceNameUtils::ParsedName requested_device; + string requested = variable_read->requested_device(); + TF_RET_CHECK( + DeviceNameUtils::ParseFullName(requested, &requested_device)); + if (requested_device.type != "TPU") { + // Stage the value via the CPU device on the remote host. The graph + // partitioner will introduce an intermediate copy rather than + // copying the same tensor multiple times across the network, and we + // would prefer that intermediate copy to be in host memory to avoid + // running out of memory if the TPUExecute op on the staging device + // starts running before the _Send ops to the other TPU devices on + // the same host complete. We don't do this if the variables are + // already placed on TPU, otherwise it will cause an unnecessary + // round trip copy. + // TODO(b/79580121): give each replica its own on-device variable + // replica and then delete this code. + string device; + TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( + tpu_device_names[replica][core], &device)); + TF_ASSIGN_OR_RETURN(auto var_data, + CreateOrGetPerHostVariableCopy( + device, variable_num, variable_reads, + params_info, arg_shardings, replicate_node, + &per_host_var_copies, graph)); + + if (arg_shardings[orig_arg_num].type() == xla::OpSharding::OTHER) { + const xla::OpSharding& sharding = arg_shardings[orig_arg_num]; + // Create or get the Split node. + TF_ASSIGN_OR_RETURN( + ShardedInputInfo sharded_input_info, + CreateOrGetSplitNodesForInputSharding( + sharding, orig_arg_num, + arg_shapes[orig_arg_num].handle_type, replica, + var_data.index, var_data.node, control_predecessor, graph, + &input_index_to_sharded_inputs)); + + // Calculate which output we should receive from the Split node. + absl::optional<int> output_index = + GetCoreIndexInSharding(sharding, core); + TF_RET_CHECK(output_index); + NodeOut split_node_and_index = + sharded_input_info.sharded_inputs[output_index.value()]; + // Connect with Split node output. + graph->AddEdge(split_node_and_index.node, + split_node_and_index.index, node, i); + + } else { + graph->AddEdge(var_data.node, var_data.index, node, i); + } + } else { + graph->AddEdge(variable_reads[variable_num], 0, node, i); + } + } + } + + // Adds a program input edge from the compiler. + graph->AddEdge(compile_node, core + 1, node, node->num_inputs() - 1); + + // Add data output edges. 
+ int num_outputs = core_retval_nums[core].size(); + for (int i = 0; i < num_outputs; ++i) { + int output_num = + replica * num_retvals_per_replica + core_retval_nums[core][i]; + const auto& sharding = retval_shardings[core_retval_nums[core][i]]; + if (sharding.type() == xla::OpSharding::OTHER) { + int retval_index = core_retval_nums[core][i]; + retval_index_to_output_index_mapping[retval_index][core] = i; + bool is_last_core = + core == + *std::max_element(sharding.tile_assignment_devices().begin(), + sharding.tile_assignment_devices().end()); + bool isPartitionOutNode = false; + + const Edge* e = replicate_output_edges[output_num]; + const Edge* e_out; + for (const Edge* out_edge : e->dst()->out_edges()) { + if (out_edge->dst()->type_string() == kTPUPartitionedOutput) { + isPartitionOutNode = true; + e_out = out_edge; + } + } + if (isPartitionOutNode) { + graph->AddEdge( + node, i, replicate_output_fan_out_nodes[output_num][core], + replicate_output_fan_out_dst_inputs[output_num][core]); + VLOG(2) << "Connect " << node->name() << " at " << i << " to " + << replicate_output_fan_out_nodes[output_num][core]->name() + << " at " + << replicate_output_fan_out_dst_inputs[output_num][core]; + if (is_last_core) { + graph->RemoveEdge(e); + graph->RemoveEdge(e_out); + } + continue; + } + + // Do this in the iteration of last core in tile assignment, so all + // TPUExecute nodes have been created. + if (!is_last_core) { + continue; + } + + // Add a Concat node. + std::vector<NodeOut> orig_inputs; + for (int64 core_id : sharding.tile_assignment_devices()) { + int core_retval_index = + retval_index_to_output_index_mapping[retval_index][core_id]; + orig_inputs.push_back( + NodeOut{execute_nodes[replica][core_id], + static_cast<int>( + core_retval_nums[core_id][core_retval_index])}); + } + DataType dtype = e->src()->output_type(e->src_output()); + TF_ASSIGN_OR_RETURN( + Node * concat_node, + CreateConcatNodesForRetval(sharding, dtype, replica, orig_inputs, + graph, /*device=*/"")); + + const Edge* edge = replicate_output_edges[output_num]; + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(concat_node, 0, dst, dst_input); + + continue; + } + + // If this is a replicated output, outputs on all cores will be the + // same, and we only take the output from core 0. + if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) { + continue; + } + + // If output has maximal sharding, make sure we only use output from + // TPUExecute node with logical core id equal to core id defined by the + // xla sharding. + if (sharding.type() == xla::OpSharding::MAXIMAL && + core != sharding.tile_assignment_devices(0)) { + continue; + } + + const Edge* replicate_edge_to_replace = + replicate_output_edges[output_num]; + Node* dst = replicate_edge_to_replace->dst(); + int dst_input = replicate_edge_to_replace->dst_input(); + graph->RemoveEdge(replicate_edge_to_replace); + graph->AddEdge(node, i, dst, dst_input); + } + + // Feed the updated variable values from the first replica to the + // variable write nodes. + if (replica == 0) { + for (int i = 0; i < core_variable_writes.size(); ++i) { + int orig_arg_num = + core_variable_writes[i] + params_info.NumPerReplicaArgs() + + params_info.NumDistributedArgs() + params_info.NumBroadcastArgs(); + const auto& sharding = arg_shardings[orig_arg_num]; + // If this is a tiling sharded variable, concat variable updates from + // all cores. 
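+        // (core_variable_writes[i] indexes into the variable arguments only;
+        // adding the per-replica, distributed and broadcast argument counts
+        // above recovers the variable's position in the full argument list,
+        // which is how arg_shardings is indexed.)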
+ if (sharding.type() == xla::OpSharding::OTHER) { + orig_arg_num_to_output_index_mapping[orig_arg_num][core] = i; + + // Do this in the iteration of last core in tile assignment, so all + // TPUExecute nodes have been created. + if (core != + *std::max_element(sharding.tile_assignment_devices().begin(), + sharding.tile_assignment_devices().end())) { + continue; + } + + // Add a Concat node. + std::vector<NodeOut> orig_inputs; + for (int64 core_id : sharding.tile_assignment_devices()) { + int core_retval_num = + orig_arg_num_to_output_index_mapping[orig_arg_num][core_id]; + orig_inputs.push_back( + NodeOut{execute_nodes[0][core_id], + static_cast<int>(core_retval_nums[core_id].size() + + core_retval_num)}); + } + + // Use the variable read's device for the concat. They should both + // be collocated with the variable. + absl::string_view device = + variable_reads[core_variable_writes[i]]->assigned_device_name(); + TF_ASSIGN_OR_RETURN( + Node * concat_node, + CreateConcatNodesForRetval( + sharding, arg_shapes[orig_arg_num].handle_type, replica, + orig_inputs, graph, device)); + // Populate VariableWrite. + VariableWrite& write = variable_writes->at(core_variable_writes[i]); + write.value = concat_node; + write.value_output = 0; + write.predicate = compile_node; + write.predicate_output = num_cores_per_replica + core + 1; + + continue; + } + + // If this is a replicated variable, outputs on all cores will be the + // same, and we only take the output from core 0 for the varialbe + // update. + if (sharding.type() == xla::OpSharding::REPLICATED && core != 0) { + continue; + } + VariableWrite& write = variable_writes->at(core_variable_writes[i]); + write.value = node; + write.value_output = num_outputs + i; + write.predicate = compile_node; + write.predicate_output = num_cores_per_replica + core + 1; + } + } + } + } + + for (Node* node : to_be_removed_nodes) { + graph->RemoveNode(node); + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationNodes( + int replica_index, const std::vector<Node*>& outside_compilation_nodes, + const DeviceNameUtils::ParsedName& tpu_device, + const DeviceNameUtils::ParsedName& partial_device, + NodeToNodeReplicasMap* node_images, Graph* graph) { + for (Node* node : outside_compilation_nodes) { + NodeDef image_def = node->def(); + MergeDebugInfo(NodeDebugInfo(node->def()), &image_def); + const string suffix = strings::StrCat("/R", replica_index); + // In addition to node name, make the frame name unique to avoid multiple + // LoopCond nodes in one frame. + TF_RETURN_IF_ERROR( + AddPrefixAndSuffixToNode("" /* prefix */, suffix, &image_def)); + Status status; + Node* image = graph->AddNode(image_def, &status); + image->AddAttr(kXlaReplicaIdAttrName, replica_index); + TF_RETURN_IF_ERROR(status); + if (HasNodeAttr(image->def(), kXlaHasHostTransferAttrName)) { + TF_RETURN_IF_ERROR( + SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, image)); + } else { + const string& original_device_string = + node->assigned_device_name().empty() ? node->requested_device() + : node->assigned_device_name(); + DeviceNameUtils::ParsedName device; + TF_RET_CHECK( + DeviceNameUtils::ParseFullName(original_device_string, &device)); + // If the requested device can be merged with the replica's host device, + // then do so. For example, if the requested device is "/CPU:0" or + // "/GPU:0" then it will be placed on the CPU/GPU of the host where this + // replica is running. 
But if the requested device is + // "/task:3/replica:2/CPU:0" then it will be placed on that task/replica. + if (DeviceNameUtils::IsSpecification(device, partial_device)) { + TF_RETURN_IF_ERROR( + DeviceNameUtils::MergeDevNames(&device, partial_device)); + } + image->set_requested_device(DeviceNameUtils::ParsedNameToString(device)); + } + std::vector<Node*>& node_image_vector = (*node_images)[node]; + node_image_vector.resize(replica_index + 1); + node_image_vector[replica_index] = image; + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationNodes( + const std::vector<std::vector<string>>& tf_device_assignment, + const HostComputeCoreMap& host_compute_core, + const OutsideCompilationNodeMap& outside_compilation_nodes, + NodeToNodeReplicasMap* node_images, Graph* graph) { + // Iterate over replicas. + for (int i = 0; i < tf_device_assignment.size(); ++i) { + const auto& core_devices = tf_device_assignment[i]; + for (const auto& oc_cluster_iter : outside_compilation_nodes) { + const string& oc_cluster_name = oc_cluster_iter.first; + const auto& oc_cluster_nodes = oc_cluster_iter.second; + // We previously validated that host_compute_core contains an entry for + // each cluster. + int core = host_compute_core.at(oc_cluster_name); + TF_RET_CHECK(core >= 0 && core < core_devices.size()); + // tpu_device is the device the HostCompute XLA Op for this cluster runs + // on. + DeviceNameUtils::ParsedName tpu_device; + TF_RET_CHECK( + DeviceNameUtils::ParseFullName(core_devices[core], &tpu_device)); + // partial_device contains the replica and task but not the type. + DeviceNameUtils::ParsedName partial_device = tpu_device; + partial_device.has_type = false; + partial_device.has_id = false; + + if (tf_device_assignment.size() == 1) { + // With a single replica don't copy any nodes just put the original + // nodes into the image map. We leave the device placement alone, except + // that we have to fill in the correct core for the host send and + // receive nodes. + for (Node* node : oc_cluster_nodes) { + (*node_images)[node] = {node}; + node->AddAttr(kXlaReplicaIdAttrName, 0); + if (HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { + TF_RETURN_IF_ERROR( + SetNodeDeviceForTPUCommunication(tpu_device, DEVICE_CPU, node)); + } + } + } else { + // Iterate over outside_compilation clusters in this computation, adding + // all the nodes with appropriate device assignments. + TF_RETURN_IF_ERROR( + CopyOutsideCompilationNodes(i, oc_cluster_nodes, tpu_device, + partial_device, node_images, graph)); + } + } + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::CopyOutsideCompilationEdges( + const std::vector<Node*>& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + const std::unordered_map<string, Node*> outside_compilation_inputs, + Graph* graph) { + for (Node* node : outside_compilation_nodes) { + const auto& images = node_images.at(node); + // Make a copy of all edges and iterate on "in_edges", because we might + // remove edges when iteratating through them. + std::vector<const Edge*> in_edges(node->in_edges().begin(), + node->in_edges().end()); + for (const Edge* edge : in_edges) { + Node* src = edge->src(); + const auto iter = node_images.find(src); + if (iter == node_images.end()) { + if (images.size() > 1) { + // The source node is a 'normal' node not part of any + // rewrite. Broadcast the value to all replicas. 
(If images.size() == + // 1 the cluster is not replicated and we can leave the original edge + // in place.) + for (Node* dst : images) { + graph->AddEdge(src, edge->src_output(), dst, edge->dst_input()); + } + } + continue; + } + + // The source node is a replicated outside_compilation node. + const auto& src_images = iter->second; + if (src_images.size() != images.size()) { + return errors::InvalidArgument( + "Graph contains an edge from node ", src->name(), + " in an outside_compilation block replicated ", src_images.size(), + " ways to node ", node->name(), + " in an outside_compilation block replicated ", images.size(), + " ways. Replication factors must match. Leave a comment on " + "tracking bug b/76419636 if you need this to be supported."); + } + bool is_lifted_arg; + string outside_compilation_cluster; + if (GetNodeAttr(src->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) + .ok() && + GetNodeAttr(src->def(), kOutsideCompilationAttr, + &outside_compilation_cluster) + .ok()) { + const auto input_iter = + outside_compilation_inputs.find(outside_compilation_cluster); + TF_RET_CHECK(input_iter != outside_compilation_inputs.end()); + TF_RET_CHECK(input_iter->second->type_string() == "IdentityN"); + int dst_input = edge->dst_input(); + if (src_images.size() == 1) { + graph->RemoveEdge(edge); + } + for (int i = 0; i < src_images.size(); ++i) { + graph->AddEdge(input_iter->second, i, images[i], dst_input); + } + continue; + } + + bool is_placeholder_for_arg; + string outside_compilation_input_attr; + if (GetNodeAttr(src->def(), kXlaIsPlaceholderForArg, + &is_placeholder_for_arg) + .ok() && + GetNodeAttr(src->def(), kXlaOutsideCompilationInputsAttrName, + &outside_compilation_input_attr) + .ok()) { + const auto input_iter = + outside_compilation_inputs.find(outside_compilation_input_attr); + TF_RET_CHECK(input_iter != outside_compilation_inputs.end()); + TF_RET_CHECK(input_iter->second->type_string() == "IdentityN"); + int dst_input = edge->dst_input(); + if (src_images.size() == 1) { + graph->RemoveEdge(edge); + } + for (int i = 0; i < src_images.size(); ++i) { + graph->AddEdge(input_iter->second, i, images[i], dst_input); + } + continue; + } + + if (images.size() > 1) { + // If images.size() == 1 neither cluster is replicated and we can + // leave the original edges in place. + for (int i = 0; i < src_images.size(); ++i) { + graph->AddEdge(src_images[i], edge->src_output(), images[i], + edge->dst_input()); + } + } + } + for (const Edge* edge : node->out_edges()) { + Node* dst = edge->dst(); + const auto iter = node_images.find(dst); + if (iter == node_images.end()) { + // The source node is a 'normal' node not part of any rewrite. + if (edge->IsControlEdge()) { + // Make the dst node have a control dependency on every replica. + if (images.size() > 1) { + for (int i = 0; i < images.size(); ++i) { + graph->AddControlEdge(images[i], dst); + } + } + // else the cluster is not replicated so we can leave the original + // edge in place. + } else { + // The edge + // is only valid if the outside_compilation block is not replicated. + if (images.size() > 1) { + return errors::InvalidArgument( + "Graph contains an edge from node ", node->name(), + " in an outside_compilation block replicated ", images.size(), + " ways to node ", dst->name(), + " that is not part of an outside_compilation block. Edges from " + "outside_compilation to regular graph nodes are only supported " + "for replication factors of 1. 
Leave a comment on tracking bug " + "b/76419636 if you need this to be supported."); + } + // else the cluster is not replicated so we can leave the original + // edge in place. + } + } + // The case where src and dst are both in node_images is covered elsewhere + // when iterating over in_edges of dst. + } + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::ReplicateOutsideCompilationEdges( + const OutsideCompilationNodeMap& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + const std::unordered_map<string, Node*> outside_compilation_inputs, + Graph* graph) { + for (const auto& oc_cluster_iter : outside_compilation_nodes) { + TF_RETURN_IF_ERROR( + CopyOutsideCompilationEdges(oc_cluster_iter.second, node_images, + outside_compilation_inputs, graph)); + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::RemoveOutsideCompilationNodes( + const NodeToNodeReplicasMap& node_images, Graph* graph) { + for (const auto& iter : node_images) { + if (iter.second.size() > 1) { + // The cluster was replicated so remove the original node. + Node* node = iter.first; + graph->RemoveNode(node); + } + } + return Status::OK(); +} + +/* static */ Status +DistributedTPURewritePass::LowerOutsideCompilationFunctionalNodes( + Graph* g, const FunctionLibraryDefinition& flib_def, + const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping) { + bool modified = false; + do { + std::vector<Node*> nodes_to_lower; + for (Node* n : g->op_nodes()) { + if (!HasNodeAttr(n->def(), kOutsideCompilationAttr)) { + continue; + } + + if (n->IsWhileNode() || n->IsIfNode() || IsFunctionCall(flib_def, *n)) { + // Only lower functional ops with DT_RESOURCE input, because otherwise + // placer will complain. For normal cases, lowering will cause slowdown + // when related functions are huge (b/139037679). + bool has_resource_input = false; + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && + e->src()->output_type(e->src_output()) == DT_RESOURCE) { + has_resource_input = true; + break; + } + } + if (has_resource_input) { + nodes_to_lower.push_back(n); + } + } + } + + modified = !nodes_to_lower.empty(); + + auto lower_functional_node = [&flib_def, &g](Node* n) -> Status { + // Clear device assignment. Otherwise all lowered nodes will have + // device assignment, which is not what we want. + n->set_requested_device(""); + + int replica_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id)); + + string outside_compilation_attr; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kOutsideCompilationAttr, + &outside_compilation_attr)); + + // There are two different kinds of functional outside compilation nodes: + // 1. Nodes that are in outside compilation blocks already. They are + // generated by FunctionalizeControlFlowForXlaPass, and only have + // attribute kOutsideCompilationAttr. + // 2. Mirrored control flow built for outside compilation in functional + // nodes. They are generated by ExtractOutsideCompilationPass, and have + // both kOutsideCompilationAttr and kXlaHasHostTransferAttrName. + // When lowering them, they need to be treated differently. + // For 1), their body functions are always V1 functions written by users, + // and their "control outputs" are control inputs of _Retval nodes. They + // should be lowered as V1 functions. + // For 2), we always add necessary "control outputs" + // (_XlaRecvAtHost/_XlaSendAtHost nodes) to "control_ret" field in their + // FunctionDef's. 
They should be lowered as V2 functions. + bool is_host_side_mirrored_control_flow = + HasNodeAttr(n->def(), kXlaHasHostTransferAttrName); + + int num_node_ids = g->num_node_ids(); + bool is_call_node = IsFunctionCall(flib_def, *n); + if (n->IsWhileNode()) { + TF_RETURN_IF_ERROR(RewriteWhileNode(n, g, + /*keep_node_fetchable=*/false)); + } else if (n->IsIfNode()) { + TF_RETURN_IF_ERROR(RewriteIfNode(n, g, /*keep_node_fetchable=*/false)); + } else { + TF_RET_CHECK(is_call_node); + // See comments for "is_host_side_mirrored_control_flow" above. + // If this is a node that's in outside compilation block, lower it as + // V1 function. This is controlled by removing + // kLowerAsMultiDeviceFunctionAttr from the node. + if (!is_host_side_mirrored_control_flow) { + n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr); + } else { + n->ClearAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr); + n->AddAttr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr, + true); + } + TF_RETURN_IF_ERROR( + RewriteFunctionCallNode(n, g, flib_def, + /*keep_caller_fetchable=*/false)); + } + + for (int i = num_node_ids; i < g->num_node_ids(); i++) { + Node* node = g->FindNodeId(i); + if (!node) { + continue; + } + + if (!is_call_node && is_host_side_mirrored_control_flow && + IsFunctionCall(flib_def, *node)) { + // For If/While nodes, if they are host side mirrored control flow, + // mark their body function calls with kXlaHasHostTransferAttrName + // attribute to make sure we lower them as V2 function. + node->AddAttr(kXlaHasHostTransferAttrName, true); + } + + if (IsFunctionCall(flib_def, *node) || node->IsWhileNode() || + node->IsIfNode()) { + // Set kOutsideCompilationAttr attribute so we lower these + // nested function call nodes later. + node->AddAttr(kOutsideCompilationAttr, outside_compilation_attr); + // Set kXlaReplicaIdAttrName attribute so we know replica id when we + // lower this function call node. + node->AddAttr(kXlaReplicaIdAttrName, replica_id); + } else if (node->type_string() == "_XlaRecvAtHost" || + node->type_string() == "_XlaSendFromHost") { + // For "_XlaRecvAtHost" and "_XlaSendFromHost" nodes, make sure they + // have kXlaReplicaIdAttrName attribute so later we know which host + // device to assign. + node->AddAttr(kXlaReplicaIdAttrName, replica_id); + } + } + return Status::OK(); + }; + + for (Node* n : nodes_to_lower) { + TF_RETURN_IF_ERROR(lower_functional_node(n)); + } + } while (modified); + + // Set device for all _XlaRecvAtHost and _XlaSendFromHost nodes. + for (Node* n : g->op_nodes()) { + if (n->type_string() != "_XlaRecvAtHost" && + n->type_string() != "_XlaSendFromHost") { + continue; + } + + string replicate; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kTPUReplicateAttr, &replicate)); + auto iter = tpu_replicate_device_names_mapping.find(replicate); + TF_RET_CHECK(iter != tpu_replicate_device_names_mapping.end()); + const auto& tpu_device_names = iter->second; + + int replica_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->def(), kXlaReplicaIdAttrName, &replica_id)); + TF_RET_CHECK(replica_id < tpu_device_names.size()); + const string& tpu_device_name = tpu_device_names[replica_id][0]; + string host_device_name; + TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( + tpu_device_name, &host_device_name)); + n->set_assigned_device_name(host_device_name); + // We may run TPU rewrite passes again on the subgraphs of the resulting + // graph. 
Clear kTPUReplicateAttr and kOutsideCompilationAttr for + // "_XlaRecvAtHost" nodes and "_XlaSendFromHost" nodes, in order to make + // sure that TPU rewrite passes take no effect on host-side subgraphs for + // outside compilation. + n->ClearAttr(kTPUReplicateAttr); + n->ClearAttr(kOutsideCompilationAttr); + } + + // Remove IdentityN nodes generated for outside compilation. IdentityN is + // exempt from resource edge colocation, but here we do need input and output + // for these IdentityN nodes to be colocated. + std::vector<Node*> identityn_nodes; + for (Node* n : g->op_nodes()) { + if (n->type_string() == "IdentityN" && + HasNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName)) { + identityn_nodes.push_back(n); + } + } + for (Node* n : identityn_nodes) { + std::vector<const Edge*> out_edges(n->out_edges().begin(), + n->out_edges().end()); + for (const Edge* e : out_edges) { + if (e->IsControlEdge()) { + continue; + } + + int src_output = e->src_output(); + const Edge* input_edge; + TF_RETURN_IF_ERROR(n->input_edge(src_output, &input_edge)); + Node* dst = e->dst(); + int dst_input = e->dst_input(); + g->RemoveEdge(e); + g->AddEdge(input_edge->src(), input_edge->src_output(), dst, dst_input); + } + g->RemoveNode(n); + } + + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::ParseHostComputeCores( + const Node& replicate_node, + const OutsideCompilationNodeMap& outside_compilation_nodes, + HostComputeCoreMap* host_compute_core) { + std::vector<string> hc_core_string; + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "host_compute_core", + &hc_core_string)); + TF_RETURN_IF_ERROR( + ParseHostComputeCoreList(hc_core_string, host_compute_core)); + for (const auto& iter : outside_compilation_nodes) { + const string& oc_cluster_name = iter.first; + if (host_compute_core->find(oc_cluster_name) == host_compute_core->end()) { + // By default put host compute Ops on replicated core 0. + (*host_compute_core)[oc_cluster_name] = 0; + } + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::GetDeviceTopology( + const DeviceSet& device_set, const Node& replicate_node, int* num_replicas, + int* num_cores_per_replica, int* num_tasks, + std::vector<std::vector<string>>* tf_device_assignment, + std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment, + string* tpu_compilation_device) { + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "num_replicas", num_replicas)); + if (*num_replicas < 1) { + return errors::InvalidArgument("num_replicas must be >= 1, got ", + *num_replicas); + } + + // Find the set of TPU devices in the TF job. + // Indexed by [task number][tpu device number]. + std::vector<std::vector<Device*>> tpu_devices; + int num_tpus_per_task; + TF_RETURN_IF_ERROR(GetTPUDeviceNames(replicate_node.requested_device(), + device_set, tpu_compilation_device, + &num_tpus_per_task, &tpu_devices)); + + string topology; + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "topology", &topology)); + TF_RETURN_IF_ERROR(GetNodeAttr( + replicate_node.attrs(), "num_cores_per_replica", num_cores_per_replica)); + std::vector<int> device_assignment; + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "device_assignment", + &device_assignment)); + + // TODO(cwhipkey): since we can control multiple pods of different shapes + // from a single worker, it may be desirable to propagate the remote device + // information around (e.g., in DeviceAttributes). 
This can lead to the mesh + // topology proto being leaked to cloud TPU users (e.g. through GetStatus + // calls); this may be okay, but to be conservative, just assume that the + // master session has the proper flags set. + + // We do not initialize platform right now, but we can still retrieve the + // TPU topology even with an uninitialized platform. + auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform( + /*initialize_platform=*/false); + TF_RET_CHECK(tpu_platform); + tpu::TpuTopologyExternal tpu_topology(tpu_platform->GetTopologyPtr()); + TF_RET_CHECK(num_tpus_per_task == + tpu_topology.LogicalDevicesPerHost(kTensorCore)); + TF_RETURN_IF_ERROR(BuildDeviceAssignment( + tpu_topology, num_tpus_per_task, tpu_devices, *num_replicas, + *num_cores_per_replica, topology, device_assignment, tf_device_assignment, + xla_device_assignment)); + + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::GetIOTypes( + int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr, + Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function, + std::unique_ptr<Graph>* computation, DataTypeVector* arg_types, + DataTypeVector* retval_types, ParameterInfo* params_info) { + DataTypeVector input_types, broadcast_input_types, guaranteed_constant_types; + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "Tinputs", &input_types)); + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), "Tbroadcast_inputs", + &broadcast_input_types)); + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), + "Tguaranteed_constants", + &guaranteed_constant_types)); + int num_distributed_vars; + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), + "num_distributed_variables", + &num_distributed_vars)); + const int num_per_replica_inputs = input_types.size() - num_distributed_vars; + + if (num_per_replica_inputs % num_replicas != 0) { + return errors::InvalidArgument( + "Number of inputs to TPUReplicate (", num_per_replica_inputs, + ") is not divisible by the number of replicas (", num_replicas, ")."); + } + + int num_variables; + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "NumVariables", &num_variables)); + + NameRangeMap output_name_map; + TF_RETURN_IF_ERROR(NameRangesForNode(replicate_node, replicate_node.op_def(), + input_name_map, &output_name_map)); + + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node.attrs(), "computation", function)); + + *computation = absl::make_unique<Graph>(graph->op_registry()); + TF_RETURN_IF_ERROR(GetComputationForTPUReplicateOp( + **function, flr, computation->get(), arg_types, retval_types)); + + *params_info = ParameterInfo( + num_replicas, num_per_replica_inputs / num_replicas, num_distributed_vars, + broadcast_input_types.size(), num_variables, + guaranteed_constant_types.size(), retval_types->size()); + + if (arg_types->size() != params_info->NumInputsToEachReplica()) { + return errors::InvalidArgument( + "Computation argument to TPUReplicate has wrong number of " + "arguments. Expected ", + params_info->NumInputsToEachReplica(), " inputs, got ", + arg_types->size()); + } + if (replicate_node.num_outputs() != params_info->NumOutputsToHost()) { + return errors::InvalidArgument( + "Wrong number of outputs from TPUReplicate. 
Expected ", + params_info->NumOutputsToHost(), " outputs, got ", + replicate_node.num_outputs()); + } + if (enable_cross_replica_sharding_mirrored_variables_) { + std::vector<int> mirrored_variable_indices; + TF_RETURN_IF_ERROR(GetNodeAttr(replicate_node.attrs(), + TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR, + &mirrored_variable_indices)); + for (int index : mirrored_variable_indices) { + TF_RET_CHECK(params_info->IsPerReplicaArg(index) || + params_info->IsDistributedArg(index)) + << "Mirrored variables not categorized as per-replica arguments, " + "index: " + << index; + params_info->mutable_mirrored_variable_indices()->insert(index); + } + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::BuildSequencingNodes( + const string& tpu_compilation_device, const Node& replicate_node, + Graph* graph, Node** host_transfer_sequencer, Node** control_before, + Node** control_after) { + *host_transfer_sequencer = nullptr; + + TF_RETURN_IF_ERROR( + BuildNoopNode(replicate_node, + graph->NewName(strings::StrCat(replicate_node.name(), "/", + "control_before")), + /*device=*/"", graph, control_before)); + for (const Edge* e : replicate_node.in_edges()) { + if (!e->IsControlEdge()) { + continue; + } + Node* predecessor = e->src(); + if (predecessor->IsSource()) continue; + if (predecessor->type_string() == "NoOp" && + predecessor->attrs().Find("_xla_host_transfer_sequencer") != nullptr) { + // The node is the sequencer for host transfer operations. Its control + // dependency needs to be placed after the execute node, not before. + if (*host_transfer_sequencer != nullptr) { + return errors::Internal("Replicate node ", replicate_node.name(), + " has two transfer sequencer nodes: ", + (*host_transfer_sequencer)->name(), " and ", + predecessor->name()); + } + // Set the correct device to match the other sequencing nodes. + predecessor->set_assigned_device_name(tpu_compilation_device); + *host_transfer_sequencer = predecessor; + } else { + graph->AddControlEdge(predecessor, *control_before); + } + } + + TF_RETURN_IF_ERROR( + BuildNoopNode(replicate_node, + graph->NewName(strings::StrCat(replicate_node.name(), "/", + "control_after")), + /*device=*/tpu_compilation_device, graph, control_after)); + for (Node* successor : replicate_node.out_nodes()) { + if (successor->attrs().Find("_xla_tail_outside_compilation") != nullptr) { + graph->AddControlEdge(successor, *control_after); + } else { + graph->AddControlEdge(*control_after, successor); + } + } + return Status::OK(); +} + +/* static */ Status DistributedTPURewritePass::DealWithConstantsAndVariables( + const Node& replicate_node, const NameRangeMap& input_name_map, + Graph* graph, Node* host_transfer_sequencer, Node* control_before, + Node* control_after, absl::Span<const VariableInput> variable_nodes, + std::vector<Node*>* guaranteed_constant_nodes, + std::vector<Node*>* variable_reads) { + TF_RETURN_IF_ERROR(FindGuaranteedConstantInputs( + replicate_node, input_name_map, guaranteed_constant_nodes)); + + TF_RETURN_IF_ERROR(BuildVariableReads(variable_nodes, control_before, graph, + variable_reads)); + // Add the control dependency from host transfer nodes. 
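+  // (The sequencer was deliberately not wired into control_before in
+  // BuildSequencingNodes; attaching it to control_after instead ensures host
+  // transfers finish before successors of the original TPUReplicate run.)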
+ if (host_transfer_sequencer != nullptr) { + graph->AddControlEdge(host_transfer_sequencer, control_after); + } + return Status::OK(); +} + +/* static */ Status +DistributedTPURewritePass::BuildCompilationStatusReturnNodes( + Node* replicate_node, Node* compile_node, Node** control_after_compilation, + Graph* graph) { + const Edge* compilation_edge = nullptr; + for (const auto* e : replicate_node->out_edges()) { + if (e->IsControlEdge() && + e->dst()->type_string() == "TPUCompilationResult") { + TF_RET_CHECK(compilation_edge == nullptr) + << "Multiple compilation result nodes attached to the same replicate " + "cluster."; + compilation_edge = e; + } + } + + // TODO(jpienaar): This should be checked by default, current tests not using + // this are ones that use the "abort upon successful compile flag" which will + // be removed. Leaving this in until then. + if (compilation_edge != nullptr) { + Node* compilation_status = compilation_edge->dst(); + const AttrValue* compile_status_cluster_attr = + compilation_status->attrs().Find(kTPUCompilationResultAttr); + TF_RET_CHECK(compile_status_cluster_attr != nullptr); + const string& compile_status_cluster = compile_status_cluster_attr->s(); + TF_RET_CHECK(!compile_status_cluster.empty()); + const AttrValue* replicate_cluster_attr = + replicate_node->attrs().Find(kTPUReplicateAttr); + TF_RET_CHECK(replicate_cluster_attr != nullptr); + const string& replicate_cluster = replicate_cluster_attr->s(); + TF_RET_CHECK(!replicate_cluster.empty()); + TF_RET_CHECK(compile_status_cluster == replicate_cluster); + + TF_RETURN_IF_ERROR( + ReplaceCompilationResultNodeWithIdentity(graph, &compilation_status)); + graph->AddEdge(compile_node, 0, compilation_status, 0); + } + + NodeDef def; + def.set_name(UniqueNodeName("tpu_compile_succeeded_assert", graph)); + // Create an op to assert that compilation succeeded. The alternative would + // have been to have each execute op check and return an error. + def.set_op("TPUCompileSucceededAssert"); + MergeDebugInfo(NodeDebugInfo(replicate_node->def()), &def); + Status status; + Node* compile_succeeded = graph->AddNode(def, &status); + compile_succeeded->set_assigned_device_name( + compile_node->assigned_device_name()); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(compile_node, 0, compile_succeeded, 0); + + // Build a sequencing node for when compilation has completed. + TF_RETURN_IF_ERROR( + BuildNoopNode(*replicate_node, + graph->NewName(strings::StrCat(compile_node->name(), "/", + "after_compilation")), + /*device=*/"", graph, control_after_compilation)); + graph->AddControlEdge(compile_succeeded, *control_after_compilation); + + return Status::OK(); +} + +// Updates the head and tail outside compiled nodes so that nodes have the +// correct device and removes the replication and outside compilation attributes +// so that these nodes do not trigger further graph optimization passes. +/* static */ Status DistributedTPURewritePass::UpdateHeadTailOutsideCompilation( + const std::vector<std::vector<string>>& tf_device_assignment, + const std::vector<Node*>& head_tail_outside_compilation_nodes) { + for (Node* node : head_tail_outside_compilation_nodes) { + int replica_id; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->def(), kXlaReplicaIdAttrName, &replica_id)); + // Since we set the device, this will now run on a task other than 0. We + // clear the two following attributes so that we don't trigger encapsulation + // again on the remote host (which will fail due to a missing + // _TPUReplicateMetadata node for the cluster). 
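+    // Placement: a node that consumes a resource _Arg is colocated with the
+    // replica's first TPU core so the DT_RESOURCE edge does not cross
+    // devices; any node still without a requested device falls back to the
+    // CPU device of the host that owns that core.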
+ for (const Edge* e : node->in_edges()) { + // Resource consuming ops should colocate with its resource input. + if (e->src()->IsArg() && + e->src()->output_type(e->src_output()) == DT_RESOURCE) { + node->set_requested_device(tf_device_assignment[replica_id][0]); + } + } + if (node->requested_device().empty()) { + string cpu_device; + TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( + tf_device_assignment[replica_id][0], &cpu_device)); + node->set_requested_device(cpu_device); + } + node->ClearAttr(kTPUReplicateAttr); + node->ClearAttr(kOutsideCompilationAttr); + } + return Status::OK(); +} + +/* static */ +Status DistributedTPURewritePass::FingerprintFunctionLibrary( + const FunctionLibraryDefinition& library, uint64* fingerprint) { + // TODO(phawkins): rather than fingerprinting the entire function library, + // consider fingerprinting just the transitive dependencies of a + // computation. + std::string serialized; + FunctionDefLibrary library_proto = library.ToProto(); + if (library_proto.ByteSizeLong() >= 1.5 * 1024 * 1024 * 1024) { + LOG(WARNING) << "Serializing large proto, size: " + << library_proto.ByteSizeLong(); + } + TF_RET_CHECK(SerializeToStringDeterministic(library_proto, &serialized)); + *fingerprint = TpuCompileInterface::Get()->FingerprintString(serialized); + return Status::OK(); +} + +// Performs the rewrite on a single TPUReplicate node. +/* static */ Status DistributedTPURewritePass::RewriteTPUReplicateNode( + const string& session_handle, const DeviceSet& device_set, + Node* replicate_node, FunctionLibraryDefinition* flib_def, + FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node, + const OutsideCompilationNodeMap& outside_compilation_nodes, + const std::vector<Node*>& head_tail_outside_compilation_nodes, + NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph, + const GraphShapeInfo& shape_info, + TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping, + int64 autotuner_thresh) { + VLOG(2) << "Rewriting node " << replicate_node->name(); + + // num_replicas and num_cores_per_replica are the 'virtual' replicas (copies + // of the computation) and cores (virtual cores within computations) specified + // by the user. They will be mapped to physical TPU cores below. + int num_replicas; + int num_cores_per_replica; + int num_tasks; // Number of tasks. 
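+  // tf_device_assignment maps [replica][core] to a TF device name, and
+  // xla_device_assignment carries the corresponding linearized TPU
+  // coordinates for the compiler (see BuildDeviceAssignment in the header).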
+ std::vector<std::vector<string>> tf_device_assignment; + std::unique_ptr<xla::DeviceAssignment> xla_device_assignment; + string tpu_compilation_device; + TF_RETURN_IF_ERROR(GetDeviceTopology( + device_set, *replicate_node, &num_replicas, &num_cores_per_replica, + &num_tasks, &tf_device_assignment, &xla_device_assignment, + &tpu_compilation_device)); + + TF_RETURN_IF_ERROR(UpdateHeadTailOutsideCompilation( + tf_device_assignment, head_tail_outside_compilation_nodes)); + + string replicate; + TF_RETURN_IF_ERROR( + GetNodeAttr(replicate_node->def(), kTPUReplicateAttr, &replicate)); + tpu_replicate_device_names_mapping->emplace(replicate, tf_device_assignment); + + NameRangeMap input_name_map; + const NameAttrList* function; + std::unique_ptr<Graph> computation; + DataTypeVector arg_types, retval_types; + ParameterInfo params_info; + TF_RETURN_IF_ERROR(GetIOTypes(num_replicas, *replicate_node, flr, graph, + &input_name_map, &function, &computation, + &arg_types, &retval_types, ¶ms_info)); + + std::vector<InferredShape> arg_shapes, retval_shapes; + TF_RETURN_IF_ERROR(GetArgAndRetvalShapes( + shape_info, *replicate_node, params_info, &arg_shapes, &retval_shapes)); + + TF_RETURN_IF_ERROR(ValidateCoreNumbers(*computation, num_cores_per_replica)); + + std::vector<xla::OpSharding> arg_sharding; + std::vector<bool> arg_fast_mem; + std::vector<xla::OpSharding> retval_sharding; + TF_RETURN_IF_ERROR(AssignArgsAndRetvalsToCores( + num_cores_per_replica, params_info, arg_types, arg_shapes, retval_types, + retval_shapes, *computation, replicate_node, flr, &arg_sharding, + &arg_fast_mem, &retval_sharding)); + + VLOG(1) << DumpGraphToFile("distributed_tpu_graph_to_replicate", *computation, + flib_def); + + GraphDef graph_def; + graph->ToGraphDef(&graph_def); + FunctionLibraryDefinition reachable_functions = + flib_def->ReachableDefinitions(graph_def); + uint64 library_fingerprint; + + TF_RETURN_IF_ERROR( + FingerprintFunctionLibrary(reachable_functions, &library_fingerprint)); + VLOG(1) << "Fingerprint functions: " + << absl::StrJoin(reachable_functions.ListFunctionNames(), ", "); + VLOG(1) << "library_fingerprint: " << library_fingerprint; + + // Builds trigger nodes that put barriers around the expansion of + // TPUReplicate. In particular, we must guarantee: + // a) variable reads happen after all predecessors of the original + // TPUReplicate. + // b) variable writes happen before all successors of the original + // TPUReplicate. + // c) all replicas execute, even if output tensors are only requested from + // a subset of replicas. This is necessary both to ensure that variable + // updates happen, but also Send/Recv will deadlock if only one half of + // the communicating pair runs. + Node* host_transfer_sequencer; + Node* control_before; + Node* control_after; + TF_RETURN_IF_ERROR(BuildSequencingNodes( + tpu_compilation_device, *replicate_node, graph, &host_transfer_sequencer, + &control_before, &control_after)); + + // Build a vector of variable nodes that are inputs. 
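+  // These are matched by name range on the TPUReplicate op and are used both
+  // to build the ReadVariableOp nodes (sequenced after control_before) and,
+  // if the computation writes any resource, the corresponding variable writes
+  // after execution.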
+ std::vector<VariableInput> variable_inputs; + TF_RETURN_IF_ERROR( + FindVariableInputs(*replicate_node, input_name_map, &variable_inputs)); + + std::vector<Node*> guaranteed_constant_nodes; + std::vector<Node*> variable_reads; + TF_RETURN_IF_ERROR(DealWithConstantsAndVariables( + *replicate_node, input_name_map, graph, host_transfer_sequencer, + control_before, control_after, variable_inputs, + &guaranteed_constant_nodes, &variable_reads)); + + // Builds Shape nodes that compute the dynamic shapes of arguments whose + // shapes are not statically known. + std::vector<Node*> dynamic_shape_nodes; + TF_RETURN_IF_ERROR(BuildDynamicShapeNodes(*replicate_node, arg_shapes, + params_info, variable_reads, graph, + &dynamic_shape_nodes)); + + // Builds a TPUCompile node that compiles `clusters` on `compile_device`. + Node* compile_node; + TF_RETURN_IF_ERROR(BuildCompileNode( + replicate_node, *function, library_fingerprint, params_info, arg_shapes, + arg_types, guaranteed_constant_nodes, session_handle, arg_sharding, + arg_fast_mem, retval_sharding, num_cores_per_replica, + /*compile_device=*/tpu_compilation_device, xla_device_assignment.get(), + dynamic_shape_nodes, graph, &compile_node, autotuner_thresh)); + + // Compilation must be sequenced after the control node if the TPU computation + // in a control-flow construct, such as a loop. + graph->AddControlEdge(control_before, compile_node); + + Node* control_after_compilation; + TF_RETURN_IF_ERROR(BuildCompilationStatusReturnNodes( + replicate_node, compile_node, &control_after_compilation, graph)); + + std::vector<VariableWrite> variable_writes; + TF_RETURN_IF_ERROR(BuildExecuteNodes( + params_info, num_tasks, num_cores_per_replica, *replicate_node, arg_types, + arg_shapes, retval_types, arg_sharding, retval_sharding, + tf_device_assignment, compile_node, variable_reads, + control_after_compilation, control_after, &variable_writes, graph)); + bool contains_resource_write_op = + ContainsResourceWriteOp(*graph, reachable_functions); + + VLOG(2) << "contains_resource_write_op: " << contains_resource_write_op; + // Skip conditional write if there is no resource writing op inside TPU + // computation. + if (contains_resource_write_op) { + TF_RETURN_IF_ERROR(BuildVariableWrites(variable_inputs, control_after, + variable_writes, graph)); + } + + if (host_compute_key_placeholder_node != nullptr) { + TF_RETURN_IF_ERROR(ConnectHostComputeNodes( + compile_node, host_compute_key_placeholder_node, graph)); + } + + HostComputeCoreMap host_compute_core; + TF_RETURN_IF_ERROR(ParseHostComputeCores( + *replicate_node, outside_compilation_nodes, &host_compute_core)); + TF_RETURN_IF_ERROR(ReplicateOutsideCompilationNodes( + tf_device_assignment, host_compute_core, outside_compilation_nodes, + outside_compilation_node_images, graph)); + + graph->RemoveNode(replicate_node); + return Status::OK(); +} + +// Adds sharded weight update optimization for each host training loop. +// +// For any host training loop found in the graph, TPUVariableReshard ops +// are inserted to match the best layout chosen by the XLA. 
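+// When a loop lives inside a function call node, the reshard ops are added to
+// the function body instead and the FunctionLibraryDefinition is updated in
+// place.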
+/* static */ Status +DistributedTPURewritePass::PerformHostTrainingLoopOptimization( + Graph* graph, FunctionLibraryDefinition* flib_def, + FunctionLibraryRuntime* flr) { + std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info; + Status s = tpu::DetectHostTrainingLoop( + /*current_function_name=*/nullptr, + /*current_function_attr=*/nullptr, flib_def, graph, flr, + &host_training_loops_info); + if (!s.ok()) { + VLOG(2) << "No valid host training loop found. Skipping sharded weight " + << "update optimization."; + return Status::OK(); + } + + for (const auto& host_loop : host_training_loops_info) { + const auto& function_name = host_loop.encapsulating_function_name; + // `function_name` has value when host training loop is inside a + // function call node. When host training loop is found inside a function + // call node, then, in addition to adding TPUVariableReshard ops, function + // library definition needs to be updated as well. + if (function_name.has_value()) { + const auto& function_attr = host_loop.encapsulating_function_attrs; + TF_RET_CHECK(function_attr.has_value()) + << "Unable to find function attribute for function: " + << *function_name; + + const FunctionDef* function_def = flib_def->Find(*function_name); + TF_RET_CHECK(function_def) + << "Unable to find function : " << *function_name; + + std::unique_ptr<FunctionBody> fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *function_def, AttrSlice(&function_attr.value()), flib_def, &fbody)); + Graph* function_graph = fbody->graph; + TF_RETURN_IF_ERROR(tpu::AddReshardOp(function_graph, host_loop)); + TF_RETURN_IF_ERROR(UpdateFunctionLibDefinition(*function_graph, + *function_name, flib_def)); + } else { + TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, host_loop)); + } + } + return Status::OK(); +} + +Status DistributedTPURewritePass::PlaceUnassignedDeviceNodesOnTPUIfPossible( + Graph* graph) { + ReverseDFS(*graph, {}, PlaceOpsOnTPU); + return Status::OK(); +} + +Status DistributedTPURewritePass::Run( + const GraphOptimizationPassOptions& options) { + VLOG(1) << "DistributedTPURewritePass::Run"; + + Graph* graph = options.graph->get(); + + VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_before", *graph, + options.flib_def); + + const auto* config = &options.session_options->config; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( + new ProcessFunctionLibraryRuntime( + nullptr, options.session_options->env, config, + graph->versions().producer(), options.flib_def, + config ? config->graph_options().optimizer_options() + : OptimizerOptions())); + + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // This pass can only run in the session master, which should fill + // in the device_set field to the options. + TF_RET_CHECK(options.device_set != nullptr); + + // Find all the replicate nodes before mutating the graph. + std::vector<Node*> replicate_nodes; + // Map from compiled subgraph cluster name to the outside_compilation nodes in + // that cluster. + std::map<string, OutsideCompilationNodeMap> outside_compilation_nodes; + std::map<string, std::vector<Node*>> head_tail_outside_compilation_nodes; + TF_RETURN_IF_ERROR(FindTaggedNodes(graph, &replicate_nodes, + &outside_compilation_nodes, + &head_tail_outside_compilation_nodes)); + + if (replicate_nodes.empty()) { + // Remove unused TPUPartitionedInput nodes. 
+ for (Node* n : graph->nodes()) { + if (n->type_string() == kTPUPartitionedInput) graph->RemoveNode(n); + } + return Status::OK(); + } + + std::unordered_map<string, Node*> host_compute_key_placeholder_map; + TF_RETURN_IF_ERROR(FindHostComputeKeyPlaceholderNodes( + graph, replicate_nodes, &host_compute_key_placeholder_map)); + + GraphShapeInfo shape_info; + TF_RETURN_IF_ERROR(InferShapes(graph, /*arg_shapes=*/{}, + flr->GetFunctionLibraryDefinition(), + &shape_info)); + int64 autotuner_thresh = options.session_options->config.experimental() + .xla_fusion_autotuner_thresh(); + + NodeToNodeReplicasMap outside_compilation_node_images; + TPUReplicateDeviceNamesMapping tpu_replicate_device_names_mapping; + for (Node* node : replicate_nodes) { + TF_RETURN_IF_ERROR(RewriteTPUReplicateNode( + options.session_handle, *options.device_set, node, options.flib_def, + flr, host_compute_key_placeholder_map[node->name()], + outside_compilation_nodes[node->name()], + head_tail_outside_compilation_nodes[node->name()], + &outside_compilation_node_images, graph, shape_info, + &tpu_replicate_device_names_mapping, autotuner_thresh)); + } + + // Place the padding nodes generated by dynamic padder on the correct devices. + // TODO(rxsang): Place padding ops on TPUs in + // PlaceUnassignedDeviceNodesOnTPUIfPossible function. + TF_RETURN_IF_ERROR(SetPaddingNodesDevices(graph)); + + std::unordered_map<string, Node*> outside_compilation_inputs; + for (Node* n : graph->op_nodes()) { + string lifted_arg_inputs_attr; + if (n->type_string() == "IdentityN" && + GetNodeAttr(n->def(), kXlaOutsideCompilationInputsAttrName, + &lifted_arg_inputs_attr) + .ok()) { + outside_compilation_inputs[lifted_arg_inputs_attr] = n; + } + } + for (const auto& iter : outside_compilation_nodes) { + TF_RETURN_IF_ERROR(ReplicateOutsideCompilationEdges( + iter.second, outside_compilation_node_images, + outside_compilation_inputs, graph)); + } + TF_RETURN_IF_ERROR( + RemoveOutsideCompilationNodes(outside_compilation_node_images, graph)); + TF_RETURN_IF_ERROR(LowerOutsideCompilationFunctionalNodes( + graph, *options.flib_def, tpu_replicate_device_names_mapping)); + + TF_RETURN_IF_ERROR(PlaceUnassignedDeviceNodesOnTPUIfPossible(graph)); + VLOG(1) << DumpGraphToFile("distributed_tpu_compilation_after", *graph, + options.flib_def); + VLOG(1) << "DistributedTPURewritePass::Run() finished"; + + if (enable_cross_replica_sharding_mirrored_variables_) { + VLOG(1) << "Starting host training loop optimization."; + VLOG(1) << DumpGraphToFile("host_loop_optimization_before", *graph, + options.flib_def); + TF_RETURN_IF_ERROR( + PerformHostTrainingLoopOptimization(graph, options.flib_def, flr)); + VLOG(1) << DumpGraphToFile("host_loop_optimization_after", *graph, + options.flib_def); + VLOG(1) << "Host training loop optimization finished."; + } + + return Status::OK(); +} + +bool DistributedTPURewritePass::distribute_vars_ = false; +bool DistributedTPURewritePass:: + replicate_inputs_outputs_by_default_for_xla_spmd_ = false; +bool DistributedTPURewritePass:: + enable_cross_replica_sharding_mirrored_variables_ = true; +bool DistributedTPURewritePass::enable_automatic_model_parallelism_ = false; + +/*static*/ void DistributedTPURewritePass::SetDistributedTpuRewritePassOptions( + bool distribute_vars, bool replicate_inputs_outputs_by_default_for_xla_spmd, + bool enable_cross_replica_sharding_mirrored_variables, + bool enable_automatic_model_parallelism) { + distribute_vars_ = distribute_vars; + replicate_inputs_outputs_by_default_for_xla_spmd_ = + 
replicate_inputs_outputs_by_default_for_xla_spmd;
+  enable_cross_replica_sharding_mirrored_variables_ =
+      enable_cross_replica_sharding_mirrored_variables;
+  enable_automatic_model_parallelism_ = enable_automatic_model_parallelism;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
new file mode 100644
index 00000000000..52fae7a7c13
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h
@@ -0,0 +1,589 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites TPUReplicate nodes into replicated computations on TPU.
+//
+// To represent a distributed TPU computation, we use the
+// TPUReplicate operator, which describes a subgraph (represented as a
+// Tensorflow function) to replicate across a TPU pod.
+//
+// Model parallelism and data parallelism:
+// ---------------------------------------
+// We support two different kinds of parallelism on TPU:
+// * data parallelism (replication), or parallelization across batches, and
+// * model parallelism, or parallelization within a batch.
+//
+// The function passed to a TPUReplicate operator is replicated many
+// times across a TPU pod (data parallelism). The `num_replicas` attribute
+// controls how many replicas of the computation to create. Replicas are mostly
+// independent; replicas can only communicate using the CrossReplicaSum
+// operator, which is typically used to communicate gradients during training.
+//
+// Each replica may optionally use more than one TPU core (model
+// parallelism). The `num_cores_per_replica` attribute controls how many cores
+// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
+// device that is only valid within replicated TPU computations (e.g.,
+// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
+// device corresponds to one TPU core in every replica.
+// Each replica runs its own copy of the computation assigned to each
+// TPU_REPLICATED_CORE device.
+//
+// The Python code is responsible for providing a device_assignment that
+// describes how the replicated logical cores map to physical cores on the TPU
+// topology.
+//
+// Inputs to TPUReplicate:
+// ------------------------------
+// The TPUReplicate operator takes four kinds of inputs, in the
+// following order:
+// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
+//   replicas, the first six arguments to TPUReplicate will be:
+//   A0 B0 C0 A1 B1 C1
+//   where Ai is the A input to the i-th replica.
+// * distributed inputs. These inputs follow the per-replica inputs.
+//   If there are two distributed inputs (E, F) and two replicas, the following
+//   arguments to TPUReplicate will be: E F.
+//   There is, however, a local copy of E and F on each replica.
+// * broadcast inputs.
These inputs follow the distributed inputs. All +// replicas receive a copy of each of these inputs. +// * variables. Resource variables accessed by the computation follow the +// broadcast inputs. +// +// For example, for a computation with two replicas, three per-replica inputs +// (A, B, C), two distributed inputs(E, F), two broadcast inputs (X, Y), and two +// variables (V, W), the arguments to TPUReplicate will be: +// A0 B0 C0 A1 B1 C1 E F X Y V W +// and each replica will receive the following arguments: +// A B C E F X Y V W +// +// Distributed TPU compilation requires that the shapes of all operators +// be known statically at compilation time, before any nodes have executed. +// Shapes are determined using shape information emitted by InferShapes. It +// is not possible to replicate Tensorflow operators with unknown or dynamic +// shapes for TPU at present. +// +// Graph rewrite: +// -------------- +// Compilation replaces TPUReplicate operators with: +// * a single TPUCompile node that compiles the computations, +// * one TPUExecute node for each TPU device in the system that +// executes the relevant computation, +// * one ReadVariableOp for each variable accessed by the replicated +// computation, +// * one AssignVariableOp for each variable accessed by the replicated +// computation. An assignment is built even if a variable is only read by the +// computation. We do not know which variables are written until we apply the +// XlaCompiler to the computation, but that does not happen until after the +// rewrite. Conservatively, we write back the values of all variables after +// the computation completes. +// TODO(phawkins): only write back variables that the computation may write. +// * one Shape node for each Tensor or Variable input to the computation whose +// shape is not statically known at rewrite time. The input shapes are fed +// to the TPUCompile node. +// +// To ensure that the reads and writes seem to happen at the right time in the +// graph execution, we add control edges from all predecessors of the original +// TPUReplicate operator to each of the ReadVariableOp operators. +// Similarly, we add control edges from all of the AssignVariableOp operators to +// all of the successors of the TPUReplicate operator. +// +// The TPUReplicate rewrite must run before placement, since resource +// variable inputs will have DT_RESOURCE, which cannot be sent across devices, +// leading to objections from the placer. The rewrite rewrites the resource +// accesses into explicit ReadVariableOp and AssignVariableOp operators that the +// placer is free to colocate with the variables. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ + +#include <string> +#include <vector> + +#include "absl/container/node_hash_map.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/stream_executor/tpu/tpu_topology.h" + +namespace tensorflow { + +// Replaces clusters assigned to TPU_SYSTEM devices with +// TPUCompile and TPUExecute nodes assigned to the corresponding +// TPU devices. 
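+//
+// The pass is configured through process-wide static options (see
+// SetDistributedTpuRewritePassOptions below). As a purely illustrative
+// example (the values here are hypothetical), a client could configure the
+// pass before the optimization registry runs it:
+//
+//   DistributedTPURewritePass::SetDistributedTpuRewritePassOptions(
+//       /*distribute_vars=*/true,
+//       /*replicate_inputs_outputs_by_default_for_xla_spmd=*/false,
+//       /*enable_cross_replica_sharding_mirrored_variables=*/true,
+//       /*enable_automatic_model_parallelism=*/false);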
+class DistributedTPURewritePass : public GraphOptimizationPass { + public: + static void SetDistributedTpuRewritePassOptions( + bool distribute_vars, + bool replicate_inputs_outputs_by_default_for_xla_spmd, + bool enable_cross_replica_sharding_mirrored_variables, + bool enable_automatic_model_parallelism); + + Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for the use of unit tests. + + // See comment at the top of the file for how the inputs are ordered. + // Encapsulates the different TPU replicated node input and output + // information, and provide common APIs over them. + class ParameterInfo { + public: + ParameterInfo() {} + ParameterInfo(int64 num_replicas, int64 num_per_replica_args, + int64 num_distributed_args, int64 num_broadcast_args, + int64 num_variables, int64 num_guaranteed_constants, + int64 num_retvals_per_replica) + : num_replicas_(num_replicas), + num_per_replica_args_(num_per_replica_args), + num_distributed_args_(num_distributed_args), + num_broadcast_args_(num_broadcast_args), + num_variables_(num_variables), + num_guaranteed_constants_(num_guaranteed_constants), + num_retvals_per_replica_(num_retvals_per_replica) {} + + int64 NumReplicas() const { return num_replicas_; } + + int64 NumPerReplicaArgs() const { return num_per_replica_args_; } + + int64 NumDistributedArgs() const { return num_distributed_args_; } + + int64 NumBroadcastArgs() const { return num_broadcast_args_; } + + int64 NumVariables() const { return num_variables_; } + + int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; } + + int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; } + + bool IsPerReplicaArg(int64 index) const { + return index < num_per_replica_args_; + } + + bool IsDistributedArg(int64 index) const { + return index >= num_per_replica_args_ && + index < (num_per_replica_args_ + num_distributed_args_); + } + + bool IsBroadcastArg(int64 index) const { + return index >= num_per_replica_args_ && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_); + } + + bool IsVariableArg(int64 index) const { + return index >= (num_per_replica_args_ + num_broadcast_args_) && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_); + } + + bool IsConstantArg(int64 index) const { + return index >= (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_) && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + + num_guaranteed_constants_); + } + + // Returns the number of inputs which has been received by the host. + int64 NumInputsFromHost() const { + return num_replicas_ * num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + num_guaranteed_constants_; + } + + // Returns the number of inputs which will be sent to each replica. + int64 NumInputsToEachReplica() const { + return num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + num_guaranteed_constants_; + } + + // Returns the total number of output values returned to the host (for all + // replicas). + int64 NumOutputsToHost() const { + return num_replicas_ * num_retvals_per_replica_; + } + + // Returns the position of the first per-replica argument, within the set + // of all hosts arguments. + // Broadcast arguments follow the distributed arguments. 
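+    // Concretely, this is num_replicas * num_per_replica_args +
+    // num_distributed_args, i.e. the index of the first broadcast argument in
+    // the flattened host argument list: with 2 replicas, 3 per-replica inputs
+    // (A, B, C) and 2 distributed inputs (E, F), the list starts
+    // A0 B0 C0 A1 B1 C1 E F X ... and the first broadcast input X is at
+    // index 8.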
+ int64 FirstBroadcastArgFromHost() const { + return num_replicas_ * num_per_replica_args_ + num_distributed_args_; + } + + // Indices of mirrored variables across replicas, which should be + // categorized as per_replica_args. + const std::set<int64>& mirrored_variable_indices() const { + return mirrored_variable_indices_; + } + std::set<int64>* mutable_mirrored_variable_indices() { + return &mirrored_variable_indices_; + } + + private: + int64 num_replicas_ = 1; + int64 num_per_replica_args_ = 0; + int64 num_distributed_args_ = 0; + int64 num_broadcast_args_ = 0; + int64 num_variables_ = 0; + int64 num_guaranteed_constants_ = 0; + int64 num_retvals_per_replica_ = 0; + std::set<int64> mirrored_variable_indices_; + }; + + // Mapping from TPUReplicate cluster name to tpu device names. Value is a + // mapping from [replica][core] to a TF device name. + typedef absl::flat_hash_map<string, std::vector<std::vector<string>>> + TPUReplicateDeviceNamesMapping; + + // Determines which devices to use to run the computation. + // Inputs: + // * num_tpus_per_task: the number of TPU devices attached to each task + // * tpu_devices: a [task][device] collection of TPU devices + // * num_replicas: the number of replicas requested + // * num_cores_per_replica: the number of cores in each computation instance + // * topology_attr: the topology TPUReplicate attribute + // * device_assignment_attr: the device_assignment TPUReplicate attribute + // Outputs: + // * tf_device_assignment: a mapping from [replica][core] to a TF device name + // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU + // coordinate. + // TODO(phawkins): change tf_device_assignment to an xla::Array2D. + static Status BuildDeviceAssignment( + const tpu::TpuTopologyExternal& topology, int num_tpus_per_task, + const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas, + int num_cores_per_replica, const string& topology_attr, + absl::Span<const int> device_assignment_attr, + std::vector<std::vector<string>>* tf_device_assignment, + std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment); + + // Returns the `computation` graph attached to TPUReplicate operator + // `node`. `flr` is a FunctionLibraryRuntime to use when + // instantiating the function body. Sets `*arg_types` and + // `*retval_types` to the argument/return types of the function. + static Status GetComputationForTPUReplicateOp(const NameAttrList& function, + FunctionLibraryRuntime* flr, + Graph* computation, + DataTypeVector* arg_types, + DataTypeVector* retval_types); + + // Returns the shapes of the argument tensors and return values of the + // TPUReplicate operator `node` using the _output_shapes, + // _output_handle_shapes, and _output_handle_types annotations on the input + // nodes. Expects inputs in the following order (see comment at top of file): + // * num_replicas * num_per_replica_args per-replica inputs, + // * num_broadcast_args broadcast inputs, + // * num_variables variable inputs. + // Returns an error if the input shapes to `node` are not statically known. + // Also verifies that all replicas have identical input shapes for their + // per-replica inputs. + static Status GetArgAndRetvalShapes( + const GraphShapeInfo& shape_info, const Node& node, + const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes, + std::vector<InferredShape>* retval_shapes); + + // Assigns arguments and return values to cores. 
The assignment is represented
+  // as an XLA op sharding, so that an argument can be replicated across cores.
+  // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
+  // argument/retval number.
+  // `arg_fast_mem` is a vector of fast_mem indicators, indexed by argument
+  // number.
+  static Status AssignArgsAndRetvalsToCores(
+      int num_cores_per_replica, const ParameterInfo& params_info,
+      const DataTypeVector& arg_types,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& retval_types,
+      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
+      const Node* replicate_node, FunctionLibraryRuntime* flr,
+      std::vector<::xla::OpSharding>* arg_sharding,
+      std::vector<bool>* arg_fast_mem,
+      std::vector<::xla::OpSharding>* retval_sharding);
+
+  // Computes a fingerprint of the contents of `library`.
+  static Status FingerprintFunctionLibrary(
+      const FunctionLibraryDefinition& library, uint64* fingerprint);
+
+  // Describes a variable input to the replicated computation: the node and
+  // output index of the resource, and the type of the variable's value.
+  struct VariableInput {
+    Node* node;
+    int index;
+
+    // Type of the variable's value. Note that this is different from the type
+    // of the output of 'variable', which is always DT_RESOURCE.
+    DataType dtype;
+  };
+
+  // Populates `*variables` with the "variables" inputs to `node`.
+  static Status FindVariableInputs(const Node& node,
+                                   const NameRangeMap& input_range_map,
+                                   std::vector<VariableInput>* variables);
+
+  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
+  // to 'node'.
+  static Status FindGuaranteedConstantInputs(
+      const Node& node, const NameRangeMap& input_range_map,
+      std::vector<Node*>* guaranteed_constants);
+
+  // Builds Shape nodes that compute the shapes of arguments whose shapes are
+  // not statically known.
+  static Status BuildDynamicShapeNodes(
+      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
+      const ParameterInfo& params_info,
+      const std::vector<Node*>& variable_reads, Graph* graph,
+      std::vector<Node*>* dynamic_shape_nodes);
+
+  // Builds a TPUCompile node that compiles the computation in `function`.
+  // TODO(b/33943292): at present, for model parallelism with Send/Recv to work
+  // the `nodes` must correspond to the computations assigned to TPU:0,
+  // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
+  // executables.
+  static Status BuildCompileNode(
+      const Node* replicate_node, const NameAttrList& function,
+      uint64 library_fingerprint, const ParameterInfo& params_info,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& arg_types,
+      const std::vector<Node*>& guaranteed_constant_nodes,
+      const string& session_handle,
+      const std::vector<::xla::OpSharding>& arg_sharding,
+      const std::vector<bool>& arg_fast_mem,
+      const std::vector<::xla::OpSharding>& retval_sharding,
+      int num_cores_per_replica, const string& compile_device,
+      const xla::DeviceAssignment* xla_device_assignment,
+      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
+      Node** compile_node, int64 autotuner_thresh);
+
+  // Builds a TPUCompileSucceededAssert node that verifies that compilation
+  // succeeded and replaces the TPUCompilationStatus node in the graph.
+  static Status BuildCompilationStatusReturnNodes(
+      Node* replicate_node, Node* compile_node,
+      Node** control_after_compilation, Graph* graph);
+
+  // Builds ReadVariableOp nodes that read `variables`, with control
+  // edges that ensure they happen after `control_predecessor`.
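+  // Illustrative only: the emitted subgraph has the shape
+  //
+  //   control_predecessor --(control)--> ReadVariableOp(variables[0])
+  //   control_predecessor --(control)--> ReadVariableOp(variables[1])
+  //   ...
+  //
+  // The reads are later paired with the writes built by BuildVariableWrites.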
+  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
+                                   Node* control_predecessor, Graph* graph,
+                                   std::vector<Node*>* variable_reads);
+
+  // Returns true if the graph or its functions contain a resource write op,
+  // otherwise returns false.
+  // TODO(b/137048563): Recognize unused resource rewrite op.
+  static bool ContainsResourceWriteOp(const Graph& graph,
+                                      const FunctionLibraryDefinition& fld);
+
+  // Struct that describes a variable value to be written back from TPUExecute.
+  struct VariableWrite {
+    // A node:output pair containing a boolean tensor that determines whether
+    // the value should be written back.
+    Node* predicate;
+    int predicate_output;
+
+    // A node:output pair containing the value to be written back.
+    Node* value;
+    int value_output;
+  };
+
+  // Builds AssignVariableOp nodes that write `variables` with the values from
+  // `variable_writes`, with control edges that ensure the writes happen before
+  // `control_successor`.
+  static Status BuildVariableWrites(
+      absl::Span<const VariableInput> variables, Node* control_successor,
+      absl::Span<const VariableWrite> variable_writes, Graph* graph);
+
+  // Builds TPUExecute operators assigned to each TPU device
+  // involved in the computation.
+  // Arguments:
+  // * `params_info` is the structure containing the information about the
+  //   TPUReplicate node inputs and outputs.
+  // * `num_tasks` is the number of TensorFlow tasks in the slice.
+  // * `num_cores_per_replica` is the number of cores which are dedicated to
+  //   each replica.
+  // * `replicate_node` is the original TPUReplicate node.
+  // * `arg_types` are the types of the arguments to the computation function
+  //   passed as argument to TPUReplicate, including per-replica,
+  //   broadcast, and variable arguments.
+  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
+  //   applicable).
+  // * `arg_shardings` and `retval_shardings` are mappings from
+  //   arguments/return indices to shardings, as returned by
+  //   `AssignArgsAndRetvalsToCores`.
+  // * `tpu_device_names` lists the devices to assign to each core of each
+  //   replica.
+  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
+  //   variable argument to the computation.
+  // * The execute operators will have a control edge from
+  //   `control_predecessor` and another control edge to `control_successor`.
+  // Populates '*variable_writes' with information about variable values to
+  // write back.
+  static Status BuildExecuteNodes(
+      const ParameterInfo& params_info, int num_tasks,
+      int num_cores_per_replica, const Node& replicate_node,
+      const DataTypeVector& arg_types,
+      const std::vector<InferredShape>& arg_shapes,
+      const DataTypeVector& retval_types,
+      const std::vector<::xla::OpSharding>& arg_shardings,
+      const std::vector<::xla::OpSharding>& retval_shardings,
+      const std::vector<std::vector<string>>& tpu_device_names,
+      Node* compile_node, const std::vector<Node*>& variable_reads,
+      Node* control_predecessor, Node* control_successor,
+      std::vector<VariableWrite>* variable_writes, Graph* graph);
+
+  // Connects the compile node to all the host transfer nodes, and removes the
+  // key placeholder node that was previously standing in for it.
+  // Arguments:
+  // * `compile_node` is the TPUCompile node that has been added to the graph.
+  // * `key_placeholder_node` is the placeholder node used to send the key to
+  //   all the host transfer nodes in the original graph.
+  // * `graph` is the graph being rewritten.
+ static Status ConnectHostComputeNodes(Node* compile_node, + Node* key_placeholder_node, + Graph* graph); + + // Map from a Node in an outside_compilation cluster in the original graph to + // the list of Nodes, one for each replica, that it is expanded into during + // replication. + typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap; + + // Map from the name of an outside_compilation cluster to the model-parallel + // core index that the HostCompute Op should be placed on in that cluster. + typedef std::map<string, int> HostComputeCoreMap; + + // Map from the name of an outside_compilation cluster to the list of Nodes + // that should run on the host for that cluster. + typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap; + + // Copies the outside_compilation nodes in a cluster to create replica + // replica_index. + static Status CopyOutsideCompilationNodes( + int replica_index, const std::vector<Node*>& outside_compilation_nodes, + const DeviceNameUtils::ParsedName& tpu_device, + const DeviceNameUtils::ParsedName& partial_device, + NodeToNodeReplicasMap* node_images, Graph* graph); + + // Replicates all the nodes in outside_compilation clusters in a compiled + // computation. + static Status ReplicateOutsideCompilationNodes( + const std::vector<std::vector<string>>& tf_device_assignment, + const HostComputeCoreMap& host_compute_core, + const OutsideCompilationNodeMap& outside_compilation_nodes, + NodeToNodeReplicasMap* node_images, Graph* graph); + + // Lifts the edges between original outside_compilation nodes in a cluster + // onto their replicas. + static Status CopyOutsideCompilationEdges( + const std::vector<Node*>& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + const std::unordered_map<string, Node*> outside_compilation_inputs, + Graph* graph); + + // Lifts all the edges in outside_compilation clusters in a compiled + // computation to their replicas. + static Status ReplicateOutsideCompilationEdges( + const OutsideCompilationNodeMap& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + const std::unordered_map<string, Node*> outside_compilation_inputs, + Graph* graph); + + // Removes all the original outside_compilation nodes from the graph, + // following replication. + static Status RemoveOutsideCompilationNodes( + const NodeToNodeReplicasMap& node_images, Graph* graph); + + // Lowers outside compilation functional nodes (If/While/function call). + // Otherwise, when we have multiple workers, device placer will not be able to + // place nodes if outside compilation has DT_RESOURCE inputs (e.g. a + // DT_RESOURCE input fed into multiple While nodes on different devices). + static Status LowerOutsideCompilationFunctionalNodes( + Graph* g, const FunctionLibraryDefinition& flib_def, + const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping); + + // Parses the 'host_compute_core' attribute on replicate_node to get the + // replicated core id of each outside_compilation cluster. + static Status ParseHostComputeCores( + const Node& replicate_node, + const OutsideCompilationNodeMap& outside_compilation_nodes, + HostComputeCoreMap* host_compute_core); + + // Gets the physical topology information about the TPU system. 
+  static Status GetDeviceTopology(
+      const DeviceSet& device_set, const Node& replicate_node,
+      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
+      std::vector<std::vector<string>>* tf_device_assignment,
+      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
+      string* tpu_compilation_device);
+
+  // Gets the types of args, retvals, and parameters.
+  static Status GetIOTypes(
+      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
+      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
+      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
+      DataTypeVector* retval_types, ParameterInfo* params_info);
+
+  // Finds known constants and deals with variable reads.
+  static Status DealWithConstantsAndVariables(
+      const Node& replicate_node, const NameRangeMap& input_name_map,
+      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
+      Node* control_after, absl::Span<const VariableInput> variable_nodes,
+      std::vector<Node*>* guaranteed_constant_nodes,
+      std::vector<Node*>* variable_reads);
+
+  // Adds NoOp nodes for sequencing computation and variable reads/writes.
+  static Status BuildSequencingNodes(const string& tpu_compilation_device,
+                                     const Node& replicate_node, Graph* graph,
+                                     Node** host_transfer_sequencer,
+                                     Node** control_before,
+                                     Node** control_after);
+
+  // Performs the pass's rewrite on the TPUReplicate node `replicate_node`.
+  static Status RewriteTPUReplicateNode(
+      const string& session_handle, const DeviceSet& device_set,
+      Node* replicate_node, FunctionLibraryDefinition* flib_def,
+      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
+      const OutsideCompilationNodeMap& outside_compilation_nodes,
+      const std::vector<Node*>& head_tail_outside_compilation_nodes,
+      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
+      const GraphShapeInfo& shape_info,
+      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
+      int64 autotuner_thresh);
+
+  // Performs host training loop optimization. For example, when a TPUExecute
+  // node is inside a while loop, model weight variables can be kept sharded
+  // in the XLA-preferred layout and unsharded only at the very last iteration,
+  // to reduce the number of all_gather ops.
+  static Status PerformHostTrainingLoopOptimization(
+      Graph* graph, FunctionLibraryDefinition* flib_def,
+      FunctionLibraryRuntime* flr);
+
+  // Heuristically places some nodes with unassigned devices on TPUs for
+  // performance reasons.
+  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);
+
+  // Updates the head and tail outside compiled nodes so that nodes have the
+  // correct device and removes the replication and outside compilation
+  // attributes so that these nodes do not trigger further graph optimization
+  // passes.
+ static Status UpdateHeadTailOutsideCompilation( + const std::vector<std::vector<string>>& tf_device_assignment, + const std::vector<Node*>& head_tail_outside_compilation_nodes); + + private: + static bool distribute_vars_; + static bool replicate_inputs_outputs_by_default_for_xla_spmd_; + static bool enable_cross_replica_sharding_mirrored_variables_; + static bool enable_automatic_model_parallelism_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc new file mode 100644 index 00000000000..18b158c0335 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h" + +#include <limits> + +#include "absl/random/random.h" + +namespace tensorflow { +namespace { + +static int64 overridden_node_id = -1; + +} // namespace + +namespace internal { + +void OverrideNodeIdForTesting(const int64 node_id) { + overridden_node_id = node_id; +} + +uint64 GetNodeId() { + if (overridden_node_id > -1) { + return overridden_node_id; + } else { + return absl::Uniform(absl::SharedBitGen(), uint64{0}, + std::numeric_limits<uint64>::max()); + } +} + +} // namespace internal +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h new file mode 100644 index 00000000000..ce80249c30f --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT +// depend on these. 
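+//
+// For example, a test within TensorFlow that wants deterministic node names
+// might do the following (illustrative only):
+//
+//   tensorflow::internal::OverrideNodeIdForTesting(42);
+//   uint64 id = tensorflow::internal::GetNodeId();  // returns 42.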
+namespace internal { + +// When set to a value >= 0, overrides the node_id. Used for getting +// deterministic node_ids during testing. +void OverrideNodeIdForTesting(int64 node_id); + +// Retrieves the node id, used to make some node names unique in the rewrite +// pass. +uint64 GetNodeId(); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc new file mode 100644 index 00000000000..fad8e22399c --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.cc @@ -0,0 +1,629 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h" + +#include <deque> +#include <map> +#include <unordered_map> + +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h" + +namespace tensorflow { +namespace tpu { + +namespace { + +constexpr char kDefaultShardingValue[] = ""; + +const Edge* FindEdgeConnecting(const Node* src, const Node* dst) { + for (const auto e : src->out_edges()) { + if (e->dst()->name() == dst->name()) return &(*e); + } + return nullptr; +} + +// Contains TPUExecute node and its DT_RESOURCE input nodes that +// correspond to model weights. +struct ExecuteNodeInfo { + Node* execute_node; + std::vector<const Edge*> var_inputs; +}; + +// Returns whether `node` is in `execute_nodes` or `(identity) -> execute`. 
+bool IsExecuteNodeOrIdentityToExecuteNode(
+    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
+  if (execute_nodes.find(node) != execute_nodes.end()) return true;
+  if (loop_nodes.find(node) == loop_nodes.end()) return false;
+  if (node->IsNextIteration()) return true;
+  if (!node->IsIdentity()) return false;
+
+  for (const Edge* e : node->out_edges()) {
+    if (e->IsControlEdge()) continue;
+
+    Node* node = e->dst();
+    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
+                                              node)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+// Starting from an input node of the TPUExecute op, finds the corresponding
+// Enter node by traversing the following pattern of nodes:
+// Enter ----> (identity) ---> While body input
+// Returns nullptr if the Enter node is not found.
+xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
+  Node* node = input_node;
+  while (node->IsIdentity()) {
+    TF_RETURN_IF_ERROR(node->input_node(0, &node));
+  }
+
+  if (node->IsEnter()) {
+    return node;
+  }
+  return nullptr;
+}
+
+xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
+    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+    const Node* enter_node, const absl::flat_hash_set<Node*> execute_nodes) {
+  for (const Edge* output_edge : enter_node->out_edges()) {
+    Node* output_node = output_edge->dst();
+    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
+
+    // If the output node is not an execute node, it must be an output of the
+    // while loop body.
+    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
+                                              output_node)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Given a TPUCompile node, finds all TPUExecute nodes that execute the
+// compiled program, as well as their model weight variable inputs.
+// The TPUCompileMetadataProto of the TPUCompile node must be reset to
+// `new_metadata` if new reshard ops are added.
+Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
+                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
+                              std::vector<ExecuteNodeInfo>* execute_node_info,
+                              TPUCompileMetadataProto* new_metadata) {
+  string metadata_string;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
+  new_metadata->ParsePartialFromString(metadata_string);
+  if (new_metadata->num_cores_per_replica() != 1) {
+    // We do not support model parallelism yet.
+ return Status::OK(); + } + + execute_node_info->clear(); + for (Node* node : compile_node->out_nodes()) { + if (node->type_string() == "TPUExecute") { + execute_node_info->push_back({node}); + } + } + if (execute_node_info->empty()) { + return Status::OK(); + } + TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas()) + << "Number of replicas does not equal number of execute nodes: " + << new_metadata->num_replicas() << " vs " << execute_node_info->size(); + DataTypeVector arg_types; + TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(), + "Targs", &arg_types)); + for (int64 i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] != DT_RESOURCE) { + continue; + } + const auto sharding_config = new_metadata->args(i).enable_xla_sharding(); + if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE && + sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) { + continue; + } + std::vector<const Edge*> edges(execute_node_info->size()); + bool is_supported = true; + std::unordered_map<Node*, absl::flat_hash_set<Node*>> + enter_to_execute_nodes; + for (int64 j = 0; j < edges.size(); ++j) { + auto execute = (*execute_node_info)[j].execute_node; + TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j])); + TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) == + arg_types[i]) + << "Execute op has an unexpected input type."; + // Traverse backwards to find the Enter node from which the input is + // passed. + // This makes sure that we are checking the usages of all potential + // aliases of the input node as well. + TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput( + edges[j]->src())); + if (enter_node == nullptr) { + is_supported = false; + enter_to_execute_nodes.clear(); + break; + } + enter_to_execute_nodes[enter_node].insert(edges[j]->dst()); + } + + for (const auto& it : enter_to_execute_nodes) { + // Size of execute nodes should be either 1 (per-replica variables) or + // num_replicas (distributed variables). + if ((it.second.size() != 1) && + (it.second.size() != new_metadata->num_replicas())) { + is_supported = false; + break; + } + TF_ASSIGN_OR_RETURN(bool no_other_use, + ResourceOnlyUsedForTPUExecuteInLoop( + graph, loop_nodes, it.first, it.second)); + if (!no_other_use) { + is_supported = false; + break; + } + } + + // Add the variable input edges only when they are supported for all + // executes. + if (is_supported) { + for (int64 j = 0; j < edges.size(); ++j) { + (*execute_node_info)[j].var_inputs.push_back(edges[j]); + } + new_metadata->mutable_args(i)->set_enable_xla_sharding( + TPUCompileMetadataProto::Arg::ALLOWED); + } + } + + int64 total = 0; + for (const auto& a : new_metadata->args()) { + if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) { + total++; + } + } + TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size()) + << " total " << total << " var_inputs " + << (*execute_node_info)[0].var_inputs.size(); + if (total == 0) { + // We don't need to process anything if no input is added. + execute_node_info->clear(); + } + return Status::OK(); +} + +bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; } + +void FindTPUCompileNodes( + const std::string* current_function_name, + const AttrValueMap* current_function_attr, + const std::unordered_map<string, WhileLoopFrame>& frames, + std::vector<HostTrainingLoopInfo>* host_training_loops_info) { + // Adds frames with no children (i.e., the innermost frames) to a worklist. 
+  std::deque<const WhileLoopFrame*> worklist;
+
+  for (auto& frame : frames) {
+    if (frame.second.num_children == 0) {
+      worklist.push_back(&frame.second);
+    }
+  }
+
+  // Check TPUCompile nodes from the innermost while loop to the outermost
+  // while loop.
+  while (!worklist.empty()) {
+    const WhileLoopFrame* frame = worklist.front();
+    worklist.pop_front();
+
+    for (const auto& n : frame->nodes) {
+      if (!IsTPUCompileOp(*n)) continue;
+
+      HostTrainingLoopInfo host_training_loop_info;
+      host_training_loop_info.compile_node_name = n->name();
+      host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
+      host_training_loop_info.while_loop_name = frame->name;
+
+      for (const auto arg : frame->args) {
+        LoopArgInfo arg_info;
+        arg_info.enter_node_name = arg.enter->name();
+        if (arg.exit) arg_info.exit_node_name = arg.exit->name();
+
+        host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
+      }
+      host_training_loop_info.loop_nodes = frame->nodes;
+
+      if (current_function_name) {
+        host_training_loop_info.encapsulating_function_name =
+            *current_function_name;
+      }
+      if (current_function_attr) {
+        host_training_loop_info.encapsulating_function_attrs =
+            *current_function_attr;
+      }
+
+      host_training_loops_info->emplace_back(
+          std::move(host_training_loop_info));
+    }
+
+    // If the parent has no remaining children, add it to the worklist.
+    --frame->parent->num_children;
+    if (frame->parent->num_children == 0) {
+      worklist.push_back(frame->parent);
+    }
+  }
+}
+
+// Starting from the while loop cond node, finds all loop exit nodes by
+// traversing the following pattern of nodes:
+// LoopCond -----> Switch -----> Exit
+std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
+  std::vector<Node*> loop_exit_nodes;
+  for (const auto e_cond : loop_cond.out_edges()) {
+    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
+    auto switch_node = e_cond->dst();
+
+    for (const auto e_switch : switch_node->out_edges()) {
+      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;
+
+      loop_exit_nodes.push_back(e_switch->dst());
+    }
+  }
+  return loop_exit_nodes;
+}
+
+// Finds any one of the switch nodes in the while loop by traversing the graph
+// from the while loop condition node.
+xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
+  Node* loop_switch_node = nullptr;
+  for (auto n : loop_cond_node.out_nodes()) {
+    if (n->IsSwitch()) {
+      loop_switch_node = n;
+      break;
+    }
+  }
+
+  TF_RET_CHECK(loop_switch_node != nullptr)
+      << "Unable to find any switch nodes.";
+  return loop_switch_node;
+}
+
+// Returns or creates a node that is executed before each iteration of the
+// while loop.
+Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
+                                          Node** node_out) {
+  // If the while loop switch node already has an outgoing data edge from the
+  // true branch of the switch op, then reuse its destination node.
+  for (const auto out_edge : loop_switch_node->out_edges()) {
+    if (out_edge->src_output() == 1) {
+      *node_out = out_edge->dst();
+      return Status::OK();
+    }
+  }
+
+  // Create an Identity node that represents execution at every loop iteration.
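+  // (Switch forwards its input to output 0, the false branch, when the loop
+  // predicate is false, i.e. on exit, and to output 1, the true branch, on
+  // every iteration in which the loop body runs, so an Identity fed from
+  // output 1 executes exactly once per iteration.)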
+  NodeDef at_loop_iteration_nodedef;
+  at_loop_iteration_nodedef.set_op("Identity");
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
+
+  AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
+  at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));
+
+  Status status;
+  Node* at_loop_iteration_node =
+      graph->AddNode(at_loop_iteration_nodedef, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
+  *node_out = at_loop_iteration_node;
+  return Status::OK();
+}
+
+// Injects an Identity node that is executed after the very last iteration
+// of the while loop but before the while loop exit node.
+Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
+                                 Node** node_out) {
+  // Find the exit node connected to the loop switch node.
+  Node* exit_node = nullptr;
+  for (const auto out_node : loop_switch_node->out_nodes()) {
+    if (out_node->IsExit()) {
+      exit_node = out_node;
+      break;
+    }
+  }
+
+  TF_RET_CHECK(exit_node != nullptr)
+      << "Cannot find exit node connected to switch node: "
+      << loop_switch_node->name();
+
+  // Create an Identity node that represents execution at the end of the last
+  // while loop iteration.
+  NodeDef after_last_loop_iteration;
+  after_last_loop_iteration.set_op("Identity");
+  DataType dtype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
+
+  AddNodeAttr("T", dtype, &after_last_loop_iteration);
+  after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
+      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
+
+  Status status;
+  Node* after_last_iteration_node =
+      graph->AddNode(after_last_loop_iteration, &status);
+  TF_RETURN_IF_ERROR(status);
+
+  // The newly created node must be executed once after the last iteration of
+  // the while loop and before the while loop exits.
+  graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
+  graph->AddControlEdge(after_last_iteration_node, exit_node);
+  *node_out = after_last_iteration_node;
+  return Status::OK();
+}
+
+}  // namespace
+
+Status DetectHostTrainingLoop(
+    const std::string* current_function_name,
+    const AttrValueMap* current_function_attr,
+    const FunctionLibraryDefinition* library, Graph* graph,
+    FunctionLibraryRuntime* flr,
+    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
+  std::vector<AssociatedFunctionInfo> associated_function_list;
+  for (const auto* n : graph->nodes()) {
+    const auto associated_functions = GetAssociatedFunctions(*n, library);
+    if (associated_functions.empty()) continue;
+
+    associated_function_list.insert(associated_function_list.end(),
+                                    associated_functions.begin(),
+                                    associated_functions.end());
+  }
+
+  Status ret_status = Status::OK();
+  for (const auto& function : associated_function_list) {
+    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;
+
+    // Convert the function to a Graph.
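+    // The conversion instantiates the function through the
+    // FunctionLibraryRuntime and borrows its FunctionBody graph; the handle is
+    // released again by the cleanup object below.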
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(), + AttrSlice(&function.attrs()), &handle)); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* function_graph = body->graph; + TF_RETURN_IF_ERROR(DetectHostTrainingLoop( + &function.func_name(), &function.attrs(), library, function_graph, flr, + host_training_loops_info)); + } + + // BuildControlFlowInfo() requires that the graph's source node is connected + // to all source nodes in the graph. Many graphs violate this invariant. + // As so, add edges to source/sink nodes so that this invariant is kept. + FixupSourceAndSinkEdges(graph); + std::vector<ControlFlowInfo> cf_info; + TF_RETURN_IF_ERROR( + BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr)); + + std::unordered_map<string, WhileLoopFrame> frames; + TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames)); + FindTPUCompileNodes(current_function_name, current_function_attr, frames, + host_training_loops_info); + return ret_status; +} + +Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) { + const auto& compile_node_name = host_loop_info.compile_node_name; + const auto node_name_map = graph->BuildNodeNameIndex(); + const auto node_it = node_name_map.find(compile_node_name); + TF_RET_CHECK(node_it != node_name_map.end()) + << "Unable to find compile node : " << compile_node_name; + + const auto compile_node = node_it->second; + std::vector<ExecuteNodeInfo> execute_nodes_info; + + Status status; + TPUCompileMetadataProto metadata; + status = + ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes, + &execute_nodes_info, &metadata); + if (!status.ok()) { + LOG(ERROR) << "Encountered error when trying to extract execute nodes, " + "skipping host loop optimization. Status: " + << status.ToString(); + return Status::OK(); + } + + if (execute_nodes_info.empty()) { + return Status::OK(); + } + + // Update the TPUCompileMetadata such that sharding config of the + // sharded resource variable inputs is set to ALLOWED instead of + // TENTATIVE. + string new_metadata_string; + metadata.SerializeToString(&new_metadata_string); + compile_node->ClearAttr("metadata"); + compile_node->AddAttr("metadata", new_metadata_string); + + // Unsharding of the model weight variables must happen only at the very + // last loop iteration. As so, add while loop condition predicate as an + // input to the sharding switch node. If loop condition is true, we do not + // unshard. + const auto& cond_node_name = host_loop_info.loop_cond_node_name; + auto loop_cond_node_it = node_name_map.find(cond_node_name); + TF_RET_CHECK(loop_cond_node_it != node_name_map.end()) + << "Cannot find loop condition node : " << cond_node_name; + auto* loop_condition_node = loop_cond_node_it->second; + + // In order to make sure that shard/unshard operations are invoked + // at the start of every loop body and at the end of last iteration + // of the loop, respectively, traverse the graph and find a switch node + // of the host training loop. 
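+  // As an informal summary of the edges added below:
+  //
+  //   switch:1 -> before_iteration Identity -> reshard op -> TPUExecute
+  //     (on every iteration)
+  //   switch:0 -> after_last_iteration Identity -> unshard op -> NoOp -> Exit(s)
+  //     (once, after the final iteration)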
+ TF_ASSIGN_OR_RETURN(Node * switch_node, + GetLoopSwitchNode(*loop_condition_node)); + + Node* after_last_iteration_node; + TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node, + &after_last_iteration_node)); + + Node* before_loop_iteration_node; + TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode( + graph, switch_node, &before_loop_iteration_node)); + + // Create const op that represents default sharding value + // (i.e. no-op sharding). + NodeDef default_sharding; + default_sharding.set_op("Const"); + default_sharding.set_name(graph->NewName(strings::StrCat( + "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId()))); + AddNodeAttr("dtype", DT_STRING, &default_sharding); + + Tensor t(DT_STRING, {2}); + t.vec<tstring>()(0) = kDefaultShardingValue; + t.vec<tstring>()(1) = kDefaultShardingValue; + t.AsProtoTensorContent( + (*default_sharding.mutable_attr())["value"].mutable_tensor()); + + Node* default_sharding_node = graph->AddNode(default_sharding, &status); + TF_RETURN_IF_ERROR(status); + // Add control edge between loop condition to make sure that + // default_sharding_node node is inside the while loop frame. + graph->AddControlEdge(loop_condition_node, default_sharding_node); + + // Build a no-op node used to add control edges after unshard nodes. + NodeDef after_unshard; + after_unshard.set_op("NoOp"); + after_unshard.set_name(graph->NewName(strings::StrCat( + "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId()))); + auto after_unshard_node = graph->AddNode(after_unshard, &status); + TF_RETURN_IF_ERROR(status); + + for (auto info : execute_nodes_info) { + auto execute_node = info.execute_node; + // Create Reshard op that optionally shards model weight variables + // prior to program execution. + NodeDef reshard_node_def; + reshard_node_def.set_name(graph->NewName(strings::StrCat( + "TPUVariableReshard/reshard", "/_", internal::GetNodeId()))); + reshard_node_def.set_op("TPUReshardVariables"); + AddNodeAttr("N", static_cast<int>(info.var_inputs.size()), + &reshard_node_def); + Node* reshard_op_node = graph->AddNode(reshard_node_def, &status); + if (!status.ok()) return status; + + reshard_op_node->set_assigned_device_name( + execute_node->assigned_device_name()); + + // Reshard op must execute at every loop iteration prior to + // TPUExecute node. + graph->AddControlEdge(before_loop_iteration_node, reshard_op_node); + graph->AddControlEdge(reshard_op_node, execute_node); + + for (int i = 0; i < info.var_inputs.size(); ++i) { + const auto variable_edge = info.var_inputs[i]; + graph->AddEdge(variable_edge->src(), variable_edge->src_output(), + reshard_op_node, i); + } + + const int new_key_input = info.var_inputs.size(); + // Add program input edge from the compiler(i.e. compilation key). + const auto compilation_key_edge = + FindEdgeConnecting(compile_node, execute_node); + graph->AddEdge(compile_node, compilation_key_edge->src_output(), + reshard_op_node, new_key_input); + + // Create VarHandleOp to store sharding state. Sharding state holds string + // compilation key that identifies whether the graph is re-compiled and the + // variables need to be sharded again. 
+ NodeDef var_handle_def; + var_handle_def.set_op("VarHandleOp"); + var_handle_def.set_name(graph->NewName(strings::StrCat( + "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId()))); + AddNodeAttr("dtype", DT_STRING, &var_handle_def); + AddNodeAttr("shape", TensorShape({}), &var_handle_def); + Node* var_handle_node = graph->AddNode(var_handle_def, &status); + if (!status.ok()) return status; + + // Add control edge between `var_handle_def` node and while loop + // loop condition so that `var_handle_def` is inside the same while loop + // frame. + // TODO(hongjunchoi): Consider adding control edge from another node--such + // as input control node. + graph->AddControlEdge(loop_condition_node, var_handle_node); + + // Connect data edge between var handle op and reshard op. + const int format_state_input = new_key_input + 1; + graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input); + + // Create Reshard op that represents unsharding after TPUExecute. + NodeDef unshard_node_def; + unshard_node_def.set_name(graph->NewName(strings::StrCat( + "TPUVariableReshard/unshard", "/_", internal::GetNodeId()))); + unshard_node_def.set_op("TPUReshardVariables"); + AddNodeAttr("N", static_cast<int>(info.var_inputs.size()), + &unshard_node_def); + Node* unshard_op_node = graph->AddNode(unshard_node_def, &status); + TF_RETURN_IF_ERROR(status); + + unshard_op_node->set_assigned_device_name( + execute_node->assigned_device_name()); + + for (int i = 0; i < info.var_inputs.size(); ++i) { + const auto variable_edge = info.var_inputs[i]; + // Connect model weight resource variables to unshard op. Since unshard op + // must be only invoked after the very last loop iteration, for each while + // loop inputs, we traverse backwards to find the switch node of the host + // training loop and connect `output_false` field of the switch node with + // unshard op. + TF_ASSIGN_OR_RETURN( + Node * enter_node, + FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src())); + graph->AddEdge(enter_node, 0, unshard_op_node, i); + } + + // Add control dependency before/after unshard node and the control nodes. + graph->AddControlEdge(after_last_iteration_node, unshard_op_node); + graph->AddControlEdge(unshard_op_node, after_unshard_node); + + graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input); + + // Add data edge from sharding state var handle op to unshard op. + graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input); + } + // Add control dependency from after_unshard_node to all exits nodes. This is + // to make sure that the unshard ops will be executed as long as any of the + // exits are used. + for (auto exit : FindLoopExitNodes(*loop_condition_node)) { + graph->AddControlEdge(after_unshard_node, exit); + } + return Status::OK(); +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h new file mode 100644 index 00000000000..822dc9edd51 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
+#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace tpu {
+
+struct LoopArgInfo {
+  std::string enter_node_name;
+  // Exit nodes are optional for loop-invariant while loop args.
+  absl::optional<std::string> exit_node_name;
+};
+
+struct HostTrainingLoopInfo {
+  // Name and attribute information about the function in which the
+  // host training loop is included. If the host training loop is not
+  // inside a function call, then `encapsulating_function_name` and
+  // `encapsulating_function_attrs` are nullopt.
+  absl::optional<std::string> encapsulating_function_name;
+  absl::optional<AttrValueMap> encapsulating_function_attrs;
+
+  // Name of the TPUCompile node within the host training loop.
+  std::string compile_node_name;
+
+  // Name of the while loop in which the TPU compile op is located.
+  std::string while_loop_name;
+
+  // Name of the node that represents the loop condition.
+  std::string loop_cond_node_name;
+
+  // Exit and Enter node names for each loop argument.
+  std::vector<LoopArgInfo> loop_arguments;
+
+  std::unordered_set<Node*> loop_nodes;  // NOLINT
+};
+
+// Walks through the `graph`, recursively if functional nodes exist, and
+// identifies all host training loops. Host training loops are the innermost
+// while loops that encapsulate a TPUCompileOp node. This information is later
+// used to introduce host-loop-specific optimizations such as the sharded
+// weight update.
+Status DetectHostTrainingLoop(
+    const std::string* current_function_name,
+    const AttrValueMap* current_function_attr,
+    const FunctionLibraryDefinition* library, Graph* graph,
+    FunctionLibraryRuntime* flr,
+    std::vector<HostTrainingLoopInfo>* host_training_loops_info);
+
+// Injects variable reshard ops before and after the TPUExecute op inside the
+// host training loop body. This effectively applies sharded weight update
+// on model weight variables.
+Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);
+
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
diff --git a/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc
new file mode 100644
index 00000000000..47187204f69
--- /dev/null
+++ b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.cc
@@ -0,0 +1,73 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h" + +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" + +namespace tensorflow { + +IncompleteNodeDefBuilder::IncompleteNodeDefBuilder(const string& name, + const string& op, + const NodeDebugInfo& debug) { + nodedef_.set_name(name); + nodedef_.set_op(op); + MergeDebugInfo(debug, &nodedef_); +} + +IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr( + const string& attr, const DataType& type) { + AddNodeAttr(attr, type, &nodedef_); + return *this; +} + +IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(const string& attr, + int val) { + AddNodeAttr(attr, val, &nodedef_); + return *this; +} + +IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::Device( + const string& device) { + nodedef_.set_device(device); + return *this; +} + +Status IncompleteNodeDefBuilder::Build(Graph* graph, Node** n) { + Status status; + *n = graph->AddNode(nodedef_, &status); + return status; +} + +IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Identity( + const string& name, const DataType& type, const NodeDebugInfo& debug) { + return IncompleteNodeDefBuilder(name, "Identity", debug).AddAttr("T", type); +} + +IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Merge( + const string& name, const DataType& type, const NodeDebugInfo& debug, + int n) { + return IncompleteNodeDefBuilder(name, "Merge", debug) + .AddAttr("T", type) + .AddAttr("N", n); +} + +IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Switch( + const string& name, const DataType& type, const NodeDebugInfo& debug) { + return IncompleteNodeDefBuilder(name, "Switch", debug).AddAttr("T", type); +} + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h new file mode 100644 index 00000000000..88e484f00cf --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_ + +#include <string> + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Convenience builder to build NodeDefs without specifying the inputs. This is +// similar to NodeDefBuilder except inputs are not specified. +// TODO(jpienaar): Clean up NodeDefBuilder and remove this class. +class IncompleteNodeDefBuilder { + public: + IncompleteNodeDefBuilder(const string& name, const string& op, + const NodeDebugInfo& debug); + + IncompleteNodeDefBuilder& AddAttr(const string& attr, const DataType& type); + IncompleteNodeDefBuilder& AddAttr(const string& attr, int val); + + IncompleteNodeDefBuilder& Device(const string& device); + + Status Build(Graph* graph, Node** n); + + static IncompleteNodeDefBuilder Identity(const string& name, + const DataType& type, + const NodeDebugInfo& debug); + static IncompleteNodeDefBuilder Merge(const string& name, + const DataType& type, + const NodeDebugInfo& debug, int n); + static IncompleteNodeDefBuilder Switch(const string& name, + const DataType& type, + const NodeDebugInfo& debug); + + private: + NodeDef nodedef_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_ diff --git a/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc b/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc index ef1831464e2..83a652d7aaa 100644 --- a/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc +++ b/tensorflow/core/tpu/graph_rewrite/tpu_rewrite_pass_registration.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h" +#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h" #include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" #include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h" @@ -30,8 +31,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34, EncapsulateTPUComputationsPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39, ExtractOutsideCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 40, + DistributedTPURewritePass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0, VariableMergerPass); - } // namespace } // namespace tensorflow diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index 6c767d1d66e..120245e34b7 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -55,8 +55,8 @@ class MultiPlatformManagerImpl { TF_LOCKS_EXCLUDED(mu_); port::StatusOr<std::vector<Platform*>> PlatformsWithFilter( - const std::function<bool(const Platform*)>& filter) - TF_LOCKS_EXCLUDED(mu_); + const std::function<bool(const Platform*)>& filter, + bool initialize_platform) TF_LOCKS_EXCLUDED(mu_); using Listener = MultiPlatformManager::Listener; port::Status RegisterListener(std::unique_ptr<Listener> listener) @@ -188,7 +188,8 @@ port::Status MultiPlatformManagerImpl::RegisterListener( port::StatusOr<std::vector<Platform*>> MultiPlatformManagerImpl::PlatformsWithFilter( - const std::function<bool(const Platform*)>& filter) { + const std::function<bool(const Platform*)>& filter, + bool initialize_platform) { absl::MutexLock lock(&mu_); CHECK_EQ(id_map_.size(), name_map_.size()); std::vector<Platform*> platforms; @@ -196,7 +197,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter( for (const auto& entry : id_map_) { Platform* platform = entry.second; if (filter(platform)) { - if (!platform->Initialized()) { + if (initialize_platform && !platform->Initialized()) { SE_RETURN_IF_ERROR(platform->Initialize({})); } platforms.push_back(platform); @@ -299,7 +300,14 @@ MultiPlatformManager::InitializePlatformWithId( /*static*/ port::StatusOr<std::vector<Platform*>> MultiPlatformManager::PlatformsWithFilter( const std::function<bool(const Platform*)>& filter) { - return Impl().PlatformsWithFilter(filter); + return PlatformsWithFilter(filter, /*initialize_platform=*/true); +} + +/*static*/ port::StatusOr<std::vector<Platform*>> +MultiPlatformManager::PlatformsWithFilter( + const std::function<bool(const Platform*)>& filter, + bool initialize_platform) { + return Impl().PlatformsWithFilter(filter, initialize_platform); } } // namespace stream_executor diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index fbb6effdf83..4fa2d819520 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -130,6 +130,10 @@ class MultiPlatformManager { static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter( const std::function<bool(const Platform*)>& filter); + static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter( + const std::function<bool(const Platform*)>& filter, + bool initialize_platform); + // Although the MultiPlatformManager "owns" its 
platforms, it holds them as // undecorated pointers to prevent races during program exit (between this // object's data and the underlying platforms (e.g., CUDA, OpenCL). diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index 7fa46ebd8d1..a8557aada48 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -331,6 +331,7 @@ cc_library( name = "tpu_topology_external", srcs = ["tpu_topology.cc"], hdrs = ["tpu_topology.h"], + visibility = ["//visibility:public"], deps = [ ":c_api_decl", "//tensorflow/core/platform:types", diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc index 7580e709bdf..fa9062c217c 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.cc +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.cc @@ -23,10 +23,11 @@ namespace tensorflow { namespace tpu { namespace { -TpuPlatformInterface* GetRegisteredPlatformStatic() { +TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform) { // Prefer TpuPlatform if it's registered. auto status_or_tpu_platform = - stream_executor::MultiPlatformManager::PlatformWithName("TPU"); + stream_executor::MultiPlatformManager::PlatformWithName( + "TPU", initialize_platform); if (status_or_tpu_platform.ok()) { return static_cast<TpuPlatformInterface*>( status_or_tpu_platform.ValueOrDie()); @@ -43,7 +44,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() { [](const stream_executor::Platform* platform) { return dynamic_cast<const TpuPlatformInterface*>(platform) != nullptr; - }); + }, + initialize_platform); if (!status_or_other_tpu_platforms.ok()) { LOG(WARNING) << "Error when getting other TPU platforms: " << status_or_tpu_platform.status(); @@ -64,9 +66,24 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() { /* static */ TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() { - // Use a local static variable to avoid data races during initialization. + return GetRegisteredPlatform(/*initialize_platform=*/true); +} + +/* static */ +TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform( + bool initialize_platform) { + static bool requested_initialize_platform = initialize_platform; static TpuPlatformInterface* tpu_registered_platform = - GetRegisteredPlatformStatic(); + GetRegisteredPlatformStatic(initialize_platform); + + if (!requested_initialize_platform && initialize_platform) { + // If the first time this function is called, we did not request + // initializing the platform, but the next caller wants the platform + // initialized, we will call GetRegisteredPlatformStatic again to initialize + // the platform. + tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform); + } + return tpu_registered_platform; } diff --git a/tensorflow/stream_executor/tpu/tpu_platform_interface.h b/tensorflow/stream_executor/tpu/tpu_platform_interface.h index da9e91ffc1c..889375245a8 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform_interface.h +++ b/tensorflow/stream_executor/tpu/tpu_platform_interface.h @@ -33,6 +33,9 @@ class TpuPlatformInterface : public stream_executor::Platform { // is registered or an error occurred. static TpuPlatformInterface* GetRegisteredPlatform(); + // Option to not initialize a platform if not necessary. 
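+  // For example (illustrative only), a caller that merely wants to check
+  // whether a TPU platform has been registered, without paying the cost of
+  // initializing it, might do:
+  //
+  //   auto* platform =
+  //       tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
+  //           /*initialize_platform=*/false);
+  //   if (platform != nullptr) { ... }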
+ static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform); + virtual Status Reset() { return Reset(false); } virtual Status Reset(bool only_tear_down) = 0; diff --git a/tensorflow/stream_executor/tpu/tpu_topology.h b/tensorflow/stream_executor/tpu/tpu_topology.h index 48371b6e008..b49b1e24386 100644 --- a/tensorflow/stream_executor/tpu/tpu_topology.h +++ b/tensorflow/stream_executor/tpu/tpu_topology.h @@ -30,6 +30,7 @@ struct TpuChipCoordinatesExternal { class TpuCoreLocationExternal { public: + TpuCoreLocationExternal() : core_location_(nullptr) {} explicit TpuCoreLocationExternal(void* core_location) : core_location_(core_location) {} TpuChipCoordinatesExternal chip_coordinates() const;