Open source distributed_tpu_rewrite_pass.cc and associated helper methods

PiperOrigin-RevId: 322460893
Change-Id: I8ca6164e8c4ce2b6d6e79db66fbb028305634ca5
Frank Chen 2020-07-21 15:58:05 -07:00 committed by TensorFlower Gardener
parent 318340f1cf
commit 145d21a90d
18 changed files with 5939 additions and 11 deletions


@@ -13,6 +13,7 @@ cc_library(
srcs = ["tpu_rewrite_pass_registration.cc"],
deps = [
":distributed_tpu_configuration_rewrite_pass",
":distributed_tpu_rewrite_pass",
":encapsulate_tpu_computations_pass",
":variable_merger_pass",
"//tensorflow/core:core_cpu",
@@ -101,3 +102,120 @@ cc_library(
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "distributed_tpu_rewrite_pass_internal",
srcs = ["distributed_tpu_rewrite_pass_internal.cc"],
hdrs = ["distributed_tpu_rewrite_pass_internal.h"],
deps = [
"//tensorflow/core:framework",
"@com_google_absl//absl/random",
],
)
cc_library(
name = "distributed_tpu_rewrite_pass",
srcs = [
"distributed_tpu_rewrite_pass.cc",
],
hdrs = [
"distributed_tpu_rewrite_pass.h",
],
deps = [
":cond_builder",
":distributed_tpu_rewrite_helpers",
":distributed_tpu_rewrite_pass_internal",
":host_training_loop_optimization_util",
":incomplete_nodedef_builder",
"//tensorflow/compiler/jit:encapsulate_util",
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:sharding_util",
"//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core/common_runtime:function",
"//tensorflow/core/common_runtime:graph_constructor",
"//tensorflow/core/common_runtime:lower_function_call_op",
"//tensorflow/core/common_runtime:lower_functional_ops",
"//tensorflow/core/common_runtime:lower_if_op",
"//tensorflow/core/common_runtime:lower_while_op",
"//tensorflow/core/common_runtime:optimization_registry",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/core/tpu:tpu_compile_interface",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"//tensorflow/stream_executor/tpu:tpu_topology_external",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "incomplete_nodedef_builder",
srcs = ["incomplete_nodedef_builder.cc"],
hdrs = ["incomplete_nodedef_builder.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "cond_builder",
srcs = ["cond_builder.cc"],
hdrs = ["cond_builder.h"],
deps = [
":incomplete_nodedef_builder",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "host_training_loop_optimization_util",
srcs = [
"host_training_loop_optimization_util.cc",
],
hdrs = [
"host_training_loop_optimization_util.h",
],
visibility = ["//visibility:public"],
deps = [
":distributed_tpu_rewrite_pass_internal",
"//tensorflow/compiler/tf2xla:functionalize_control_flow_util",
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/types:optional",
],
)
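
With ":distributed_tpu_rewrite_pass" added to the deps of the registration target above, tpu_rewrite_pass_registration.cc can register the new pass with TensorFlow's optimization registry. A minimal sketch of such a registration follows; the grouping and phase number are illustrative assumptions, not values taken from this commit:

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

namespace tensorflow {
// The pass must run before placement (see the header comment in
// distributed_tpu_rewrite_pass.h); the phase number 20 is an assumption.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
                      DistributedTPURewritePass);
}  // namespace tensorflow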


@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
namespace tensorflow {
CondBuilder::CondBuilder(string name, string device, const NodeDebugInfo& debug,
Graph* graph)
: graph_(graph), name_(std::move(name)), device_(std::move(device)) {
auto new_name = [graph, this](string suffix) {
return graph->NewName(strings::StrCat(name_, "/", suffix));
};
TF_CHECK_OK(
IncompleteNodeDefBuilder::Identity(new_name("pred"), DT_BOOL, debug)
.Device(device_)
.Build(graph_, &pred_));
Node* switch_pred;
TF_CHECK_OK(
IncompleteNodeDefBuilder::Switch(new_name("switch_pred"), DT_BOOL, debug)
.Device(device_)
.Build(graph_, &switch_pred));
graph_->AddEdge(pred(), 0, switch_pred, 0);
graph_->AddEdge(pred(), 0, switch_pred, 1);
TF_CHECK_OK(
IncompleteNodeDefBuilder::Identity(new_name("switch_f"), DT_BOOL, debug)
.Device(device_)
.Build(graph_, &switch_f_));
TF_CHECK_OK(
IncompleteNodeDefBuilder::Identity(new_name("switch_t"), DT_BOOL, debug)
.Device(device_)
.Build(graph_, &switch_t_));
graph_->AddEdge(switch_pred, kElseBranch, switch_f_, 0);
graph_->AddEdge(switch_pred, kThenBranch, switch_t_, 0);
Node* merge_pred;
TF_CHECK_OK(IncompleteNodeDefBuilder::Merge(new_name("merge_pred"), DT_BOOL,
debug, /*n=*/2)
.Device(device_)
.Build(graph_, &merge_pred));
graph_->AddEdge(switch_f_, 0, merge_pred, kElseBranch);
graph_->AddEdge(switch_t_, 0, merge_pred, kThenBranch);
// Note: when additional return values are added, there should be a control
// dependency between those merge nodes and control_successor_ to ensure that
// it is the control successor of the conditional.
control_successor_ = merge_pred;
}
Node* CondBuilder::pred() { return pred_; }
Node* CondBuilder::switch_f() { return switch_f_; }
Node* CondBuilder::switch_t() { return switch_t_; }
Node* CondBuilder::control_successor() { return control_successor_; }
Status CondBuilder::AddInput(const string& input_name, const DataType& type,
const string& device, const NodeDebugInfo& debug,
Node** input) {
auto b = IncompleteNodeDefBuilder::Switch(
graph_->NewName(strings::StrCat(name_, "/", input_name)), type, debug);
TF_RETURN_IF_ERROR(b.Device(device).Build(graph_, input));
graph_->AddEdge(pred(), 0, *input, 1);
return Status::OK();
}
} // namespace tensorflow


@@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
#include <string>
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Conditional builder.
// Convenience builder to make it easy to construct a conditional. E.g.,
//     Node* pred = ...;
//     CondBuilder cb("cond", device, debug, g);
//     Node* switch_var;
//     TF_CHECK_OK(
//         cb.AddInput("var", DT_RESOURCE, device, debug, &switch_var));
//     g->AddEdge(pred, 0, cb.pred(), 0);
// This creates the nodes of a conditional that takes a resource variable
// ("var") as input and switches on pred.
//
// This currently only handles the case needed by distributed_tpu_rewrite_pass
// and is not completely general.
class CondBuilder {
public:
enum Branch { kElseBranch = 0, kThenBranch = 1 };
CondBuilder(string name, string device, const NodeDebugInfo& debug,
Graph* graph);
// Returns node corresponding to the predicate input.
Node* pred();
// Returns node corresponding to switch_f branch of predicate switch.
Node* switch_f();
// Returns node corresponding to switch_t branch of predicate switch.
Node* switch_t();
// Returns node corresponding to control successor.
Node* control_successor();
// Returns the Switch node to feed a value of the given type into the
// conditional.
Status AddInput(const string& input_name, const DataType& type,
const string& device, const NodeDebugInfo& debug,
Node** input);
private:
Node* control_successor_;
Node* switch_f_;
Node* switch_t_;
Node* pred_;
Graph* const graph_;
const string name_;
const string device_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
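
For illustration, a minimal usage sketch of CondBuilder (not part of this commit), assuming a caller that already has a Graph* g, a Node* pred_source producing a boolean, a device string, and a NodeDebugInfo; the function name is hypothetical:

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"

tensorflow::Status BuildResourceCond(tensorflow::Graph* g,
                                     tensorflow::Node* pred_source,
                                     const std::string& device,
                                     const tensorflow::NodeDebugInfo& debug) {
  tensorflow::CondBuilder cb("cond", device, debug, g);
  // Feed the actual predicate value into the conditional's predicate input.
  g->AddEdge(pred_source, 0, cb.pred(), 0);
  // Add a resource variable input; `switch_var` is the Switch node whose
  // kElseBranch/kThenBranch outputs feed the two branches.
  tensorflow::Node* switch_var;
  TF_RETURN_IF_ERROR(
      cb.AddInput("var", tensorflow::DT_RESOURCE, device, debug, &switch_var));
  return tensorflow::Status::OK();
}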

File diff suppressed because it is too large


@@ -0,0 +1,589 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Rewrites TPUReplicate nodes into replicated computations on TPU.
//
// To represent a distributed TPU computation, we use the
// TPUReplicate operator, that describes a subgraph (represented as a
// Tensorflow function) to replicate across a TPU pod.
//
// Model parallelism and data parallelism:
// ---------------------------------------
// We support two different kinds of parallelism on TPU:
// * data parallelism (replication), or parallelization across batches, and
// * model parallelism, or parallelization within a batch.
//
// The function passed to a TPUReplicate operator is replicated many
// times across a TPU pod (data parallelism). The `num_replicas` attribute
// controls how many replicas of the computation to create. Replicas are mostly
// independent; replicas can only communicate using the CrossReplicaSum
// operator, which is typically used to communicate gradients during training.
//
// Each replica may optionally use more than one TPU core (model
// parallelism). The `num_cores_per_replica` attribute controls how many cores
// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
// device that is only valid within replicated TPU computations (e.g.,
// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
// device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
// TPU_REPLICATED_CORE device.
//
// The Python code is responsible for providing a device_assignment that
// describes how the replicated logical cores map to physical cores on the TPU
// topology.
//
// Inputs to TPUReplicate:
// ------------------------------
// The TPUReplicate operator takes four kinds of inputs, in the
// following order:
// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
// replicas, the first six arguments to TPUReplicate will be:
// A0 B0 C0 A1 B1 C1
// where Ai is the A input to the i-th replica.
// * distributed inputs. These inputs follow the per-replica inputs.
// If there are two distributed inputs (E, F) and two replicas, the next
// arguments to TPUReplicate will be: E F.
// Each replica receives its own local copy of E and F.
// * broadcast inputs. These inputs follow the distributed inputs. All
// replicas receive a copy of each of these inputs.
// * variables. Resource variables accessed by the computation follow the
// broadcast inputs.
//
// For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and two
// variables (V, W), the arguments to TPUReplicate will be:
// A0 B0 C0 A1 B1 C1 E F X Y V W
// and each replica will receive the following arguments:
// A B C E F X Y V W
//
// Distributed TPU compilation requires that the shapes of all operators
// be known statically at compilation time, before any nodes have executed.
// Shapes are determined using shape information emitted by InferShapes. It
// is not possible to replicate Tensorflow operators with unknown or dynamic
// shapes for TPU at present.
//
// Graph rewrite:
// --------------
// Compilation replaces TPUReplicate operators with:
// * a single TPUCompile node that compiles the computations,
// * one TPUExecute node for each TPU device in the system that
// executes the relevant computation,
// * one ReadVariableOp for each variable accessed by the replicated
// computation,
// * one AssignVariableOp for each variable accessed by the replicated
// computation. An assignment is built even if a variable is only read by the
// computation. We do not know which variables are written until we apply the
// XlaCompiler to the computation, but that does not happen until after the
// rewrite. Conservatively, we write back the values of all variables after
// the computation completes.
// TODO(phawkins): only write back variables that the computation may write.
// * one Shape node for each Tensor or Variable input to the computation whose
// shape is not statically known at rewrite time. The input shapes are fed
// to the TPUCompile node.
//
// To ensure that the reads and writes seem to happen at the right time in the
// graph execution, we add control edges from all predecessors of the original
// TPUReplicate operator to each of the ReadVariableOp operators.
// Similarly, we add control edges from all of the AssignVariableOp operators to
// all of the successors of the TPUReplicate operator.
//
// The TPUReplicate rewrite must run before placement, since resource
// variable inputs will have DT_RESOURCE, which cannot be sent across devices,
// leading to objections from the placer. The rewrite rewrites the resource
// accesses into explicit ReadVariableOp and AssignVariableOp operators that the
// placer is free to colocate with the variables.
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#include <string>
#include <vector>
#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
namespace tensorflow {
// Replaces clusters assigned to TPU_SYSTEM devices with
// TPUCompile and TPUExecute nodes assigned to the corresponding
// TPU devices.
class DistributedTPURewritePass : public GraphOptimizationPass {
public:
static void SetDistributedTpuRewritePassOptions(
bool distribute_vars,
bool replicate_inputs_outputs_by_default_for_xla_spmd,
bool enable_cross_replica_sharding_mirrored_variables,
bool enable_automatic_model_parallelism);
Status Run(const GraphOptimizationPassOptions& options) override;
// The following methods are public only for the use of unit tests.
// See comment at the top of the file for how the inputs are ordered.
// Encapsulates the different TPU replicated node input and output
// information, and provides common APIs over them.
class ParameterInfo {
public:
ParameterInfo() {}
ParameterInfo(int64 num_replicas, int64 num_per_replica_args,
int64 num_distributed_args, int64 num_broadcast_args,
int64 num_variables, int64 num_guaranteed_constants,
int64 num_retvals_per_replica)
: num_replicas_(num_replicas),
num_per_replica_args_(num_per_replica_args),
num_distributed_args_(num_distributed_args),
num_broadcast_args_(num_broadcast_args),
num_variables_(num_variables),
num_guaranteed_constants_(num_guaranteed_constants),
num_retvals_per_replica_(num_retvals_per_replica) {}
int64 NumReplicas() const { return num_replicas_; }
int64 NumPerReplicaArgs() const { return num_per_replica_args_; }
int64 NumDistributedArgs() const { return num_distributed_args_; }
int64 NumBroadcastArgs() const { return num_broadcast_args_; }
int64 NumVariables() const { return num_variables_; }
int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; }
int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; }
bool IsPerReplicaArg(int64 index) const {
return index < num_per_replica_args_;
}
bool IsDistributedArg(int64 index) const {
return index >= num_per_replica_args_ &&
index < (num_per_replica_args_ + num_distributed_args_);
}
bool IsBroadcastArg(int64 index) const {
return index >= (num_per_replica_args_ + num_distributed_args_) &&
index < (num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_);
}
bool IsVariableArg(int64 index) const {
return index >= (num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_) &&
index < (num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_ + num_variables_);
}
bool IsConstantArg(int64 index) const {
return index >= (num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_ + num_variables_) &&
index < (num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_ + num_variables_ +
num_guaranteed_constants_);
}
// Returns the number of inputs which have been received by the host.
int64 NumInputsFromHost() const {
return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
}
// Returns the number of inputs which will be sent to each replica.
int64 NumInputsToEachReplica() const {
return num_per_replica_args_ + num_distributed_args_ +
num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
}
// Returns the total number of output values returned to the host (for all
// replicas).
int64 NumOutputsToHost() const {
return num_replicas_ * num_retvals_per_replica_;
}
// Returns the position of the first broadcast argument within the set of
// all host arguments. Broadcast arguments follow the per-replica and
// distributed arguments.
int64 FirstBroadcastArgFromHost() const {
return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
}
// Indices of mirrored variables across replicas, which should be
// categorized as per_replica_args.
const std::set<int64>& mirrored_variable_indices() const {
return mirrored_variable_indices_;
}
std::set<int64>* mutable_mirrored_variable_indices() {
return &mirrored_variable_indices_;
}
private:
int64 num_replicas_ = 1;
int64 num_per_replica_args_ = 0;
int64 num_distributed_args_ = 0;
int64 num_broadcast_args_ = 0;
int64 num_variables_ = 0;
int64 num_guaranteed_constants_ = 0;
int64 num_retvals_per_replica_ = 0;
std::set<int64> mirrored_variable_indices_;
};
// Mapping from TPUReplicate cluster name to tpu device names. Value is a
// mapping from [replica][core] to a TF device name.
typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
TPUReplicateDeviceNamesMapping;
// Determines which devices to use to run the computation.
// Inputs:
// * num_tpus_per_task: the number of TPU devices attached to each task
// * tpu_devices: a [task][device] collection of TPU devices
// * num_replicas: the number of replicas requested
// * num_cores_per_replica: the number of cores in each computation instance
// * topology_attr: the topology TPUReplicate attribute
// * device_assignment_attr: the device_assignment TPUReplicate attribute
// Outputs:
// * tf_device_assignment: a mapping from [replica][core] to a TF device name
// * xla_device_assignment: a mapping from [replica][core] to a linearized TPU
// coordinate.
// TODO(phawkins): change tf_device_assignment to an xla::Array2D.
static Status BuildDeviceAssignment(
const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
int num_cores_per_replica, const string& topology_attr,
absl::Span<const int> device_assignment_attr,
std::vector<std::vector<string>>* tf_device_assignment,
std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);
// Populates `computation` with the body of the function `function` attached
// to a TPUReplicate operator. `flr` is a FunctionLibraryRuntime to use when
// instantiating the function body. Sets `*arg_types` and
// `*retval_types` to the argument/return types of the function.
static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
FunctionLibraryRuntime* flr,
Graph* computation,
DataTypeVector* arg_types,
DataTypeVector* retval_types);
// Returns the shapes of the argument tensors and return values of the
// TPUReplicate operator `node` using the _output_shapes,
// _output_handle_shapes, and _output_handle_types annotations on the input
// nodes. Expects inputs in the following order (see comment at top of file):
// * num_replicas * num_per_replica_args per-replica inputs,
// * num_distributed_args distributed inputs,
// * num_broadcast_args broadcast inputs,
// * num_variables variable inputs.
// Returns an error if the input shapes to `node` are not statically known.
// Also verifies that all replicas have identical input shapes for their
// per-replica inputs.
static Status GetArgAndRetvalShapes(
const GraphShapeInfo& shape_info, const Node& node,
const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
std::vector<InferredShape>* retval_shapes);
// Assigns arguments and return values to cores. The assignment is represented
// as an XLA op sharding, so that an argument can be replicated across cores.
// `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
// argument/retval number.
// `arg_fast_mem` is a vector of fast_mem indications, indexed by
// argument number.
static Status AssignArgsAndRetvalsToCores(
int num_cores_per_replica, const ParameterInfo& params_info,
const DataTypeVector& arg_types,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& retval_types,
const std::vector<InferredShape>& retval_shapes, const Graph& graph,
const Node* replicate_node, FunctionLibraryRuntime* flr,
std::vector<::xla::OpSharding>* arg_sharding,
std::vector<bool>* arg_fast_mem,
std::vector<::xla::OpSharding>* retval_sharding);
// Computes a fingerprint of the contents of `library`.
static Status FingerprintFunctionLibrary(
const FunctionLibraryDefinition& library, uint64* fingerprint);
// Describes a variable input to the replicated computation: the node and
// output index that produce the resource handle, and the type of the
// variable's value.
struct VariableInput {
Node* node;
int index;
// Type of the variable's value. Note that this is different to the type of
// the output of 'variable', which is always DT_RESOURCE.
DataType dtype;
};
// Populates `*variables` with the "variables" inputs to `node`.
static Status FindVariableInputs(const Node& node,
const NameRangeMap& input_range_map,
std::vector<VariableInput>* variables);
// Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
// to 'node'.
static Status FindGuaranteedConstantInputs(
const Node& node, const NameRangeMap& input_range_map,
std::vector<Node*>* guaranteed_constants);
// Builds Shape nodes that compute the shapes of arguments whose shapes are
// not statically known.
static Status BuildDynamicShapeNodes(
const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
const ParameterInfo& params_info,
const std::vector<Node*>& variable_reads, Graph* graph,
std::vector<Node*>* dynamic_shape_nodes);
// Builds a TPUCompile node that compiles the computation in `function`.
// TODO(b/33943292): at present, for model parallelism with Send/Recv to work
// the `nodes` must correspond to the computations assigned to TPU:0,
// TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
// executables.
static Status BuildCompileNode(
const Node* replicate_node, const NameAttrList& function,
uint64 library_fingerprint, const ParameterInfo& params_info,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& arg_types,
const std::vector<Node*>& guaranteed_constant_nodes,
const string& session_handle,
const std::vector<::xla::OpSharding>& arg_sharding,
const std::vector<bool>& arg_fast_mem,
const std::vector<::xla::OpSharding>& retval_sharding,
int num_cores_per_replica, const string& compile_device,
const xla::DeviceAssignment* xla_device_assignment,
const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
Node** compile_node, int64 autotuner_thresh);
// Builds a TPUCompileSucceededAssert node that verifies that compilation
// succeeded and replaces the TPUCompilationStatus node in the graph.
static Status BuildCompilationStatusReturnNodes(
Node* replicate_node, Node* compile_node,
Node** control_after_compilation, Graph* graph);
// Builds ReadVariableOp nodes that read `variables`, with control edges
// that ensure they happen after `control_predecessor`.
static Status BuildVariableReads(absl::Span<const VariableInput> variables,
Node* control_predecessor, Graph* graph,
std::vector<Node*>* variable_reads);
// Returns true if the graph or its functions contain a resource write op,
// otherwise returns false.
// TODO(b/137048563): Recognize unused resource rewrite op.
static bool ContainsResourceWriteOp(const Graph& graph,
const FunctionLibraryDefinition& fld);
// Struct that describes a variable value to be written back from TPUExecute.
struct VariableWrite {
// A node:output pair containing a boolean tensor that determines whether
// the value should be written back.
Node* predicate;
int predicate_output;
// A node:output pair containing the value to be written back.
Node* value;
int value_output;
};
// Builds AssignVariableOp nodes that write `variables` with the values from
// `variable_writes`, with control edges that ensure the writes happen before
// `control_successor`.
static Status BuildVariableWrites(
absl::Span<const VariableInput> variables, Node* control_successor,
absl::Span<const VariableWrite> variable_writes, Graph* graph);
// Builds TPUExecute operators assigned to each TPU device
// involved in the computation.
// Arguments:
// * `params_info` is the structure containing the information about the
// TPUReplicate node inputs and outputs.
// * `num_tasks` is the number of TensorFlow tasks in the slice.
// * `num_cores_per_replica` is the number of cores which are dedicated to
// each replica.
// * `replicate_node` is the original TPUReplicate node.
// * `arg_types` are the types of the arguments to the computation function
// passed as argument to TPUReplicate, including per-replica,
// broadcast, and variable arguments.
// * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
// applicable).
// * `arg_shardings` and `retval_shardings` are mappings from
// arguments/return indices to shardings, as returned by
// `AssignArgsAndRetvalsToCores`.
// * `pod_devices` lists the devices to assign to each core of each replica.
// * `variable_reads` is a vector of ReadVariableOp operators, one for each
// variable argument to the computation.
// * The execute operators will have a control edge from
// `control_predecessor` and another control edge to `control_successor`.
// Populates '*variable_writes' with information about variable values to
// write back.
static Status BuildExecuteNodes(
const ParameterInfo& params_info, int num_tasks,
int num_cores_per_replica, const Node& replicate_node,
const DataTypeVector& arg_types,
const std::vector<InferredShape>& arg_shapes,
const DataTypeVector& retval_types,
const std::vector<::xla::OpSharding>& arg_shardings,
const std::vector<::xla::OpSharding>& retval_shardings,
const std::vector<std::vector<string>>& tpu_device_names,
Node* compile_node, const std::vector<Node*>& variable_reads,
Node* control_predecessor, Node* control_successor,
std::vector<VariableWrite>* variable_writes, Graph* graph);
// Connects the compile node to all the host transfer nodes, and removes the
// key placeholder node that was previously standing in for it.
// Arguments:
// * `compile_node` is the TPUCompile node that has been added to the graph.
// * `key_placeholder_node` is the placeholder node used to send the key to
//   all the host transfer nodes in the original graph.
// * `graph` is the graph being rewritten.
static Status ConnectHostComputeNodes(Node* compile_node,
Node* key_placeholder_node,
Graph* graph);
// Map from a Node in an outside_compilation cluster in the original graph to
// the list of Nodes, one for each replica, that it is expanded into during
// replication.
typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;
// Map from the name of an outside_compilation cluster to the model-parallel
// core index that the HostCompute Op should be placed on in that cluster.
typedef std::map<string, int> HostComputeCoreMap;
// Map from the name of an outside_compilation cluster to the list of Nodes
// that should run on the host for that cluster.
typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;
// Copies the outside_compilation nodes in a cluster to create replica
// replica_index.
static Status CopyOutsideCompilationNodes(
int replica_index, const std::vector<Node*>& outside_compilation_nodes,
const DeviceNameUtils::ParsedName& tpu_device,
const DeviceNameUtils::ParsedName& partial_device,
NodeToNodeReplicasMap* node_images, Graph* graph);
// Replicates all the nodes in outside_compilation clusters in a compiled
// computation.
static Status ReplicateOutsideCompilationNodes(
const std::vector<std::vector<string>>& tf_device_assignment,
const HostComputeCoreMap& host_compute_core,
const OutsideCompilationNodeMap& outside_compilation_nodes,
NodeToNodeReplicasMap* node_images, Graph* graph);
// Lifts the edges between original outside_compilation nodes in a cluster
// onto their replicas.
static Status CopyOutsideCompilationEdges(
const std::vector<Node*>& outside_compilation_nodes,
const NodeToNodeReplicasMap& node_images,
const std::unordered_map<string, Node*> outside_compilation_inputs,
Graph* graph);
// Lifts all the edges in outside_compilation clusters in a compiled
// computation to their replicas.
static Status ReplicateOutsideCompilationEdges(
const OutsideCompilationNodeMap& outside_compilation_nodes,
const NodeToNodeReplicasMap& node_images,
const std::unordered_map<string, Node*> outside_compilation_inputs,
Graph* graph);
// Removes all the original outside_compilation nodes from the graph,
// following replication.
static Status RemoveOutsideCompilationNodes(
const NodeToNodeReplicasMap& node_images, Graph* graph);
// Lowers outside compilation functional nodes (If/While/function call).
// Otherwise, when we have multiple workers, device placer will not be able to
// place nodes if outside compilation has DT_RESOURCE inputs (e.g. a
// DT_RESOURCE input fed into multiple While nodes on different devices).
static Status LowerOutsideCompilationFunctionalNodes(
Graph* g, const FunctionLibraryDefinition& flib_def,
const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);
// Parses the 'host_compute_core' attribute on replicate_node to get the
// replicated core id of each outside_compilation cluster.
static Status ParseHostComputeCores(
const Node& replicate_node,
const OutsideCompilationNodeMap& outside_compilation_nodes,
HostComputeCoreMap* host_compute_core);
// Gets the physical topology information about the TPU system.
static Status GetDeviceTopology(
const DeviceSet& device_set, const Node& replicate_node,
int* num_replicas, int* num_cores_per_replica, int* num_tasks,
std::vector<std::vector<string>>* tf_device_assignment,
std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
string* tpu_compilation_device);
// Gets the types of args, retvals, and parameters.
static Status GetIOTypes(
int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
DataTypeVector* retval_types, ParameterInfo* params_info);
// Finds known constants and deals with variable reads.
static Status DealWithConstantsAndVariables(
const Node& replicate_node, const NameRangeMap& input_name_map,
Graph* graph, Node* host_transfer_sequencer, Node* control_before,
Node* control_after, absl::Span<const VariableInput> variable_nodes,
std::vector<Node*>* guaranteed_constant_nodes,
std::vector<Node*>* variable_reads);
// Adds NoOp nodes for sequencing computation and variable reads/writes.
static Status BuildSequencingNodes(const string& tpu_compilation_device,
const Node& replicate_node, Graph* graph,
Node** host_transfer_sequencer,
Node** control_before,
Node** control_after);
// Performs the pass's rewrite on a TPUReplicate node `node`.
static Status RewriteTPUReplicateNode(
const string& session_handle, const DeviceSet& device_set,
Node* replicate_node, FunctionLibraryDefinition* flib_def,
FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
const OutsideCompilationNodeMap& outside_compilation_nodes,
const std::vector<Node*>& head_tail_outside_compilation_nodes,
NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
const GraphShapeInfo& shape_info,
TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
int64 autotuner_thresh);
// Performs host training loop optimization. For example, when a TPUExecute
// node is inside a while loop, model weight variables can be sharded in the
// XLA-preferred layout and unsharded only at the very last iteration,
// reducing the number of all-gather operations.
static Status PerformHostTrainingLoopOptimization(
Graph* graph, FunctionLibraryDefinition* flib_def,
FunctionLibraryRuntime* flr);
// Heuristically place some nodes with unassigned devices on TPUs for
// performance reasons.
static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);
// Updates the head and tail outside compiled nodes so that nodes have the
// correct device and removes the replication and outside compilation
// attributes so that these nodes do not trigger further graph optimization
// passes.
static Status UpdateHeadTailOutsideCompilation(
const std::vector<std::vector<string>>& tf_device_assignment,
const std::vector<Node*>& head_tail_outside_compilation_nodes);
private:
static bool distribute_vars_;
static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
static bool enable_cross_replica_sharding_mirrored_variables_;
static bool enable_automatic_model_parallelism_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
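
To make the index arithmetic above concrete, here is a small sketch (not part of this commit) that instantiates ParameterInfo for the A0 B0 C0 A1 B1 C1 E F X Y V W example from the header comment; the function name is hypothetical:

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

void ParameterInfoIndexingSketch() {
  using ParameterInfo = tensorflow::DistributedTPURewritePass::ParameterInfo;
  // Two replicas; per-replica args A, B, C; distributed args E, F;
  // broadcast args X, Y; variables V, W; no guaranteed constants.
  ParameterInfo info(/*num_replicas=*/2, /*num_per_replica_args=*/3,
                     /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
                     /*num_variables=*/2, /*num_guaranteed_constants=*/0,
                     /*num_retvals_per_replica=*/1);
  // The host-side argument list A0 B0 C0 A1 B1 C1 E F X Y V W has 12 entries.
  CHECK_EQ(info.NumInputsFromHost(), 12);
  // Each replica receives A B C E F X Y V W, i.e. 9 inputs.
  CHECK_EQ(info.NumInputsToEachReplica(), 9);
  // The first broadcast argument (X) sits at host index 2*3 + 2 = 8.
  CHECK_EQ(info.FirstBroadcastArgFromHost(), 8);
}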


@@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
#include <limits>
#include "absl/random/random.h"
namespace tensorflow {
namespace {
static int64 overridden_node_id = -1;
} // namespace
namespace internal {
void OverrideNodeIdForTesting(const int64 node_id) {
overridden_node_id = node_id;
}
uint64 GetNodeId() {
if (overridden_node_id > -1) {
return overridden_node_id;
} else {
return absl::Uniform(absl::SharedBitGen(), uint64{0},
std::numeric_limits<uint64>::max());
}
}
} // namespace internal
} // namespace tensorflow


@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT
// depend on these.
namespace internal {
// When set to a value >= 0, overrides the node_id. Used for getting
// deterministic node_ids during testing.
void OverrideNodeIdForTesting(int64 node_id);
// Retrieves the node id, used to make some node names unique in the rewrite
// pass.
uint64 GetNodeId();
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
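
A short sketch (not part of this commit) of how a unit test might use OverrideNodeIdForTesting to make node names generated via GetNodeId() deterministic; the test names are hypothetical:

#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

namespace tensorflow {

TEST(DistributedTpuRewritePassInternalTest, OverrideNodeId) {
  // Without an override, GetNodeId() returns a uniformly random uint64.
  internal::OverrideNodeIdForTesting(42);
  // Once overridden with a value >= 0, GetNodeId() keeps returning that
  // value, so generated node names become deterministic.
  EXPECT_EQ(internal::GetNodeId(), 42);
  EXPECT_EQ(internal::GetNodeId(), 42);
}

}  // namespace tensorflow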


@@ -0,0 +1,629 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include <deque>
#include <map>
#include <unordered_map>
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
namespace tensorflow {
namespace tpu {
namespace {
constexpr char kDefaultShardingValue[] = "";
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
for (const auto e : src->out_edges()) {
if (e->dst()->name() == dst->name()) return e;
}
return nullptr;
}
// Contains TPUExecute node and its DT_RESOURCE input nodes that
// correspond to model weights.
struct ExecuteNodeInfo {
Node* execute_node;
std::vector<const Edge*> var_inputs;
};
// Returns whether `node` is in `execute_nodes` or `(identity) -> execute`.
bool IsExecuteNodeOrIdentityToExecuteNode(
const Graph& graph, const std::unordered_set<Node*>& loop_nodes, // NOLINT
const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
if (execute_nodes.find(node) != execute_nodes.end()) return true;
if (loop_nodes.find(node) == loop_nodes.end()) return false;
if (node->IsNextIteration()) return true;
if (!node->IsIdentity()) return false;
for (const Edge* e : node->out_edges()) {
if (e->IsControlEdge()) continue;
Node* node = e->dst();
if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
node)) {
return false;
}
}
return true;
}
// Starting from an input node to a TPUExecute op, finds the corresponding
// Enter node by traversing the following pattern of nodes:
// Enter ----> (identity) ---> While body input
// Returns nullptr if no Enter node is found.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
Node* node = input_node;
while (node->IsIdentity()) {
TF_RETURN_IF_ERROR(node->input_node(0, &node));
}
if (node->IsEnter()) {
return node;
}
return nullptr;
}
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
const Graph& graph, const std::unordered_set<Node*>& loop_nodes, // NOLINT
const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
for (const Edge* output_edge : enter_node->out_edges()) {
Node* output_node = output_edge->dst();
if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
// If the output node is not an execute node, it must be an output node
// of the while loop body.
if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
output_node)) {
return false;
}
}
return true;
}
// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, along with their model weight variable inputs.
// The TPUCompileMetadataProto of the TPUCompile node must be reset to
// `new_metadata` if new reshard ops are added.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
const std::unordered_set<Node*>& loop_nodes, // NOLINT
std::vector<ExecuteNodeInfo>* execute_node_info,
TPUCompileMetadataProto* new_metadata) {
string metadata_string;
TF_RETURN_IF_ERROR(
GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
new_metadata->ParsePartialFromString(metadata_string);
if (new_metadata->num_cores_per_replica() != 1) {
// We do not support model parallelism yet.
return Status::OK();
}
execute_node_info->clear();
for (Node* node : compile_node->out_nodes()) {
if (node->type_string() == "TPUExecute") {
execute_node_info->push_back({node});
}
}
if (execute_node_info->empty()) {
return Status::OK();
}
TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
<< "Number of replicas does not equal number of execute nodes: "
<< new_metadata->num_replicas() << " vs " << execute_node_info->size();
DataTypeVector arg_types;
TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
"Targs", &arg_types));
for (int64 i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] != DT_RESOURCE) {
continue;
}
const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
continue;
}
std::vector<const Edge*> edges(execute_node_info->size());
bool is_supported = true;
std::unordered_map<Node*, absl::flat_hash_set<Node*>>
enter_to_execute_nodes;
for (int64 j = 0; j < edges.size(); ++j) {
auto execute = (*execute_node_info)[j].execute_node;
TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
arg_types[i])
<< "Execute op has an unexpected input type.";
// Traverse backwards to find the Enter node from which the input is
// passed.
// This makes sure that we are checking the usages of all potential
// aliases of the input node as well.
TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
edges[j]->src()));
if (enter_node == nullptr) {
is_supported = false;
enter_to_execute_nodes.clear();
break;
}
enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
}
for (const auto& it : enter_to_execute_nodes) {
// Size of execute nodes should be either 1 (per-replica variables) or
// num_replicas (distributed variables).
if ((it.second.size() != 1) &&
(it.second.size() != new_metadata->num_replicas())) {
is_supported = false;
break;
}
TF_ASSIGN_OR_RETURN(bool no_other_use,
ResourceOnlyUsedForTPUExecuteInLoop(
graph, loop_nodes, it.first, it.second));
if (!no_other_use) {
is_supported = false;
break;
}
}
// Add the variable input edges only when they are supported for all
// executes.
if (is_supported) {
for (int64 j = 0; j < edges.size(); ++j) {
(*execute_node_info)[j].var_inputs.push_back(edges[j]);
}
new_metadata->mutable_args(i)->set_enable_xla_sharding(
TPUCompileMetadataProto::Arg::ALLOWED);
}
}
int64 total = 0;
for (const auto& a : new_metadata->args()) {
if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
total++;
}
}
TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
<< " total " << total << " var_inputs "
<< (*execute_node_info)[0].var_inputs.size();
if (total == 0) {
// We don't need to process anything if no input is added.
execute_node_info->clear();
}
return Status::OK();
}
bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }
void FindTPUCompileNodes(
const std::string* current_function_name,
const AttrValueMap* current_function_attr,
const std::unordered_map<string, WhileLoopFrame>& frames,
std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
// Adds frames with no children (i.e., the innermost frames) to a worklist.
std::deque<const WhileLoopFrame*> worklist;
for (auto& frame : frames) {
if (frame.second.num_children == 0) {
worklist.push_back(&frame.second);
}
}
// Check TPUCompile node from the innermost while loop to the outermost
// while loop.
while (!worklist.empty()) {
const WhileLoopFrame* frame = worklist.front();
worklist.pop_front();
for (const auto& n : frame->nodes) {
if (!IsTPUCompileOp(*n)) continue;
HostTrainingLoopInfo host_training_loop_info;
host_training_loop_info.compile_node_name = n->name();
host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
host_training_loop_info.while_loop_name = frame->name;
for (const auto arg : frame->args) {
LoopArgInfo arg_info;
arg_info.enter_node_name = arg.enter->name();
if (arg.exit) arg_info.exit_node_name = arg.exit->name();
host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
}
host_training_loop_info.loop_nodes = frame->nodes;
if (current_function_name) {
host_training_loop_info.encapsulating_function_name =
*current_function_name;
}
if (current_function_attr) {
host_training_loop_info.encapsulating_function_attrs =
*current_function_attr;
}
host_training_loops_info->emplace_back(
std::move(host_training_loop_info));
}
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
if (frame->parent->num_children == 0) {
worklist.push_back(frame->parent);
}
}
}
// From a while loop cond node, finds all loop exit nodes by traversing the
// following pattern of nodes:
// LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
std::vector<Node*> loop_exit_nodes;
for (const auto e_cond : loop_cond.out_edges()) {
if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
auto switch_node = e_cond->dst();
for (const auto e_switch : switch_node->out_edges()) {
if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;
loop_exit_nodes.push_back(e_switch->dst());
}
}
return loop_exit_nodes;
}
// Finds any one of the switch nodes in the while loop by traversing the
// graph from the while loop condition node.
xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
Node* loop_switch_node = nullptr;
for (auto n : loop_cond_node.out_nodes()) {
if (n->IsSwitch()) {
loop_switch_node = n;
break;
}
}
TF_RET_CHECK(loop_switch_node != nullptr)
<< "Unable to find any switch nodes.";
return loop_switch_node;
}
// Returns or creates a node that is executed before each loop iteration
// of the while loop.
Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
Node** node_out) {
// If the while loop switch node already has an outgoing data edge to the
// true branch of the switch op, reuse that node.
for (const auto out_edge : loop_switch_node->out_edges()) {
if (out_edge->src_output() == 1) {
*node_out = out_edge->dst();
return Status::OK();
}
}
// Create Identity node that represents execution at every loop iteration.
NodeDef at_loop_iteration_nodedef;
at_loop_iteration_nodedef.set_op("Identity");
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));
Status status;
Node* at_loop_iteration_node =
graph->AddNode(at_loop_iteration_nodedef, &status);
TF_RETURN_IF_ERROR(status);
graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
*node_out = at_loop_iteration_node;
return Status::OK();
}
// Injects a node that is executed after the very last iteration
// of the while loop but before the while loop exit node.
Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
Node** node_out) {
// Find the exit node from loop switch node.
Node* exit_node = nullptr;
for (const auto out_node : loop_switch_node->out_nodes()) {
if (out_node->IsExit()) {
exit_node = out_node;
break;
}
}
TF_RET_CHECK(exit_node != nullptr)
<< "Cannot find exit node connected to switch node :"
<< loop_switch_node->name();
// Create a node that represents execution at the end of the last
// iteration of the while loop.
NodeDef after_last_loop_iteration;
after_last_loop_iteration.set_op("Identity");
DataType dtype;
TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
AddNodeAttr("T", dtype, &after_last_loop_iteration);
after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
Status status;
Node* after_last_iteration_node =
graph->AddNode(after_last_loop_iteration, &status);
TF_RETURN_IF_ERROR(status);
// Newly created node must be executed once after last iteration of the while
// loop and before while loop exits.
graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
graph->AddControlEdge(after_last_iteration_node, exit_node);
*node_out = after_last_iteration_node;
return Status::OK();
}
} // namespace
Status DetectHostTrainingLoop(
const std::string* current_function_name,
const AttrValueMap* current_function_attr,
const FunctionLibraryDefinition* library, Graph* graph,
FunctionLibraryRuntime* flr,
std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
std::vector<AssociatedFunctionInfo> associated_function_list;
for (const auto* n : graph->nodes()) {
const auto associated_functions = GetAssociatedFunctions(*n, library);
if (associated_functions.empty()) continue;
associated_function_list.insert(associated_function_list.end(),
associated_functions.begin(),
associated_functions.end());
}
Status ret_status = Status::OK();
for (const auto& function : associated_function_list) {
if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
AttrSlice(&function.attrs()), &handle));
auto cleanup_handle = gtl::MakeCleanup([&]() {
auto s = flr->ReleaseHandle(handle);
if (!s.ok()) {
ret_status.Update(s);
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* function_graph = body->graph;
TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
&function.func_name(), &function.attrs(), library, function_graph, flr,
host_training_loops_info));
}
// BuildControlFlowInfo() requires that the graph's source node is connected
// to all source nodes in the graph. Many graphs violate this invariant.
// Therefore, add edges to source/sink nodes so that this invariant is kept.
FixupSourceAndSinkEdges(graph);
std::vector<ControlFlowInfo> cf_info;
TF_RETURN_IF_ERROR(
BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));
std::unordered_map<string, WhileLoopFrame> frames;
TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
FindTPUCompileNodes(current_function_name, current_function_attr, frames,
host_training_loops_info);
return ret_status;
}
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
const auto& compile_node_name = host_loop_info.compile_node_name;
const auto node_name_map = graph->BuildNodeNameIndex();
const auto node_it = node_name_map.find(compile_node_name);
TF_RET_CHECK(node_it != node_name_map.end())
<< "Unable to find compile node : " << compile_node_name;
const auto compile_node = node_it->second;
std::vector<ExecuteNodeInfo> execute_nodes_info;
Status status;
TPUCompileMetadataProto metadata;
status =
ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
&execute_nodes_info, &metadata);
if (!status.ok()) {
LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
"skipping host loop optimization. Status: "
<< status.ToString();
return Status::OK();
}
if (execute_nodes_info.empty()) {
return Status::OK();
}
// Update the TPUCompileMetadata such that sharding config of the
// sharded resource variable inputs is set to ALLOWED instead of
// TENTATIVE.
string new_metadata_string;
metadata.SerializeToString(&new_metadata_string);
compile_node->ClearAttr("metadata");
compile_node->AddAttr("metadata", new_metadata_string);
// Unsharding of the model weight variables must happen only at the very
// last loop iteration. Therefore, add the while loop condition predicate as an
// input to the sharding switch node. If loop condition is true, we do not
// unshard.
const auto& cond_node_name = host_loop_info.loop_cond_node_name;
auto loop_cond_node_it = node_name_map.find(cond_node_name);
TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
<< "Cannot find loop condition node : " << cond_node_name;
auto* loop_condition_node = loop_cond_node_it->second;
// In order to make sure that shard/unshard operations are invoked
// at the start of every loop body and at the end of last iteration
// of the loop, respectively, traverse the graph and find a switch node
// of the host training loop.
TF_ASSIGN_OR_RETURN(Node * switch_node,
GetLoopSwitchNode(*loop_condition_node));
Node* after_last_iteration_node;
TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
&after_last_iteration_node));
Node* before_loop_iteration_node;
TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
graph, switch_node, &before_loop_iteration_node));
// Create const op that represents default sharding value
// (i.e. no-op sharding).
NodeDef default_sharding;
default_sharding.set_op("Const");
default_sharding.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
AddNodeAttr("dtype", DT_STRING, &default_sharding);
Tensor t(DT_STRING, {2});
t.vec<tstring>()(0) = kDefaultShardingValue;
t.vec<tstring>()(1) = kDefaultShardingValue;
t.AsProtoTensorContent(
(*default_sharding.mutable_attr())["value"].mutable_tensor());
Node* default_sharding_node = graph->AddNode(default_sharding, &status);
TF_RETURN_IF_ERROR(status);
// Add a control edge from the loop condition node to make sure that
// default_sharding_node is inside the while loop frame.
graph->AddControlEdge(loop_condition_node, default_sharding_node);
// Build a no-op node used to add control edges after unshard nodes.
NodeDef after_unshard;
after_unshard.set_op("NoOp");
after_unshard.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
auto after_unshard_node = graph->AddNode(after_unshard, &status);
TF_RETURN_IF_ERROR(status);
for (const auto& info : execute_nodes_info) {
auto execute_node = info.execute_node;
// Create Reshard op that optionally shards model weight variables
// prior to program execution.
NodeDef reshard_node_def;
reshard_node_def.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
reshard_node_def.set_op("TPUReshardVariables");
AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
&reshard_node_def);
Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
if (!status.ok()) return status;
reshard_op_node->set_assigned_device_name(
execute_node->assigned_device_name());
// Reshard op must execute at every loop iteration prior to
// TPUExecute node.
graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
graph->AddControlEdge(reshard_op_node, execute_node);
for (int i = 0; i < info.var_inputs.size(); ++i) {
const auto variable_edge = info.var_inputs[i];
graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
reshard_op_node, i);
}
const int new_key_input = info.var_inputs.size();
// Add program input edge from the compiler (i.e. the compilation key).
const auto compilation_key_edge =
FindEdgeConnecting(compile_node, execute_node);
graph->AddEdge(compile_node, compilation_key_edge->src_output(),
reshard_op_node, new_key_input);
// Create VarHandleOp to store sharding state. The sharding state holds the
// string compilation key that identifies whether the graph has been
// re-compiled, in which case the variables need to be sharded again.
NodeDef var_handle_def;
var_handle_def.set_op("VarHandleOp");
var_handle_def.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
AddNodeAttr("dtype", DT_STRING, &var_handle_def);
AddNodeAttr("shape", TensorShape({}), &var_handle_def);
Node* var_handle_node = graph->AddNode(var_handle_def, &status);
if (!status.ok()) return status;
// Add a control edge between the `var_handle_def` node and the while loop
// condition so that `var_handle_def` is inside the same while loop frame.
// TODO(hongjunchoi): Consider adding control edge from another node--such
// as input control node.
graph->AddControlEdge(loop_condition_node, var_handle_node);
// Connect data edge between var handle op and reshard op.
const int format_state_input = new_key_input + 1;
graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);
// Create Reshard op that represents unsharding after TPUExecute.
NodeDef unshard_node_def;
unshard_node_def.set_name(graph->NewName(strings::StrCat(
"TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
unshard_node_def.set_op("TPUReshardVariables");
AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
&unshard_node_def);
Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
TF_RETURN_IF_ERROR(status);
unshard_op_node->set_assigned_device_name(
execute_node->assigned_device_name());
for (int i = 0; i < info.var_inputs.size(); ++i) {
const auto variable_edge = info.var_inputs[i];
// Connect model weight resource variables to the unshard op. Since the
// unshard op must be invoked only after the very last loop iteration, for
// each while loop input we traverse backwards to find the switch node of
// the host training loop and connect the `output_false` field of the switch
// node with the unshard op.
TF_ASSIGN_OR_RETURN(
Node * enter_node,
FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
graph->AddEdge(enter_node, 0, unshard_op_node, i);
}
// Wire the unshard node between the control nodes: after the last-iteration
// node and before the after-unshard no-op node.
graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
graph->AddControlEdge(unshard_op_node, after_unshard_node);
graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);
// Add data edge from sharding state var handle op to unshard op.
graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
}
// Add control dependency from after_unshard_node to all exit nodes. This is
// to make sure that the unshard ops will be executed as long as any of the
// exits are used.
for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
graph->AddControlEdge(after_unshard_node, exit);
}
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow
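
For orientation, the wiring that AddReshardOp produces can be summarized as below. This is an editorial sketch using the local variable names from the function above, not code from the commit:

// Per TPUExecute node ("-->" denotes a control edge):
//
//   before_loop_iteration_node --> reshard_op_node --> execute_node
//   after_last_iteration_node  --> unshard_op_node --> after_unshard_node
//   after_unshard_node         --> (every loop Exit node)
//
// Data edges: the variable inputs and the compilation key feed
// reshard_op_node; the variables' Enter nodes and default_sharding_node feed
// unshard_op_node; var_handle_node feeds both ops as the sharding state.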

View File

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace tpu {
struct LoopArgInfo {
std::string enter_node_name;
// Exit nodes are optional for loop-invariant while loop args.
absl::optional<std::string> exit_node_name;
};
struct HostTrainingLoopInfo {
// Name and attribute information about the function in which the host
// training loop is included. If the host training loop is not inside a
// function call, then `encapsulating_function_name` and
// `encapsulating_function_attrs` are nullopt.
absl::optional<std::string> encapsulating_function_name;
absl::optional<AttrValueMap> encapsulating_function_attrs;
// Name of the TPU compile node within the host training loop.
std::string compile_node_name;
// Name of the while loop in which TPU compile op is located.
std::string while_loop_name;
// Name of the node that represents loop condition.
std::string loop_cond_node_name;
// Exit and Enter node names for each loop argument.
std::vector<LoopArgInfo> loop_arguments;
std::unordered_set<Node*> loop_nodes; // NOLINT
};
// Walks through the `graph`, recursing into functional nodes if they exist,
// and identifies all host training loops. Host training loops are the
// innermost while loops that encapsulate a TPUCompileOp node. They are later
// used/analyzed to introduce host-loop-specific optimizations such as the
// sharded weight update.
Status DetectHostTrainingLoop(
const std::string* current_function_name,
const AttrValueMap* current_function_attr,
const FunctionLibraryDefinition* library, Graph* graph,
FunctionLibraryRuntime* flr,
std::vector<HostTrainingLoopInfo>* host_training_loops_info);
// Injects VariableReshardOps before and after the TPUExecute op inside the
// host training loop body. This effectively applies the sharded weight
// update to the model weight variables.
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
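
A hedged sketch of how a rewrite pass might drive the two entry points above. The `library`, `graph`, and `flr` values are assumed to come from the enclosing optimization pass, and passing null function context for a top-level graph is an assumption of this example, not something the header guarantees:

// Sketch only: assumes a Status-returning caller and a top-level graph.
std::vector<tpu::HostTrainingLoopInfo> host_training_loops_info;
TF_RETURN_IF_ERROR(tpu::DetectHostTrainingLoop(
    /*current_function_name=*/nullptr, /*current_function_attr=*/nullptr,
    library, graph, flr, &host_training_loops_info));
// Inject shard/unshard ops into every detected host training loop.
for (const auto& loop_info : host_training_loops_info) {
  TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, loop_info));
}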

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
namespace tensorflow {
IncompleteNodeDefBuilder::IncompleteNodeDefBuilder(const string& name,
const string& op,
const NodeDebugInfo& debug) {
nodedef_.set_name(name);
nodedef_.set_op(op);
MergeDebugInfo(debug, &nodedef_);
}
IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(
const string& attr, const DataType& type) {
AddNodeAttr(attr, type, &nodedef_);
return *this;
}
IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(const string& attr,
int val) {
AddNodeAttr(attr, val, &nodedef_);
return *this;
}
IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::Device(
const string& device) {
nodedef_.set_device(device);
return *this;
}
Status IncompleteNodeDefBuilder::Build(Graph* graph, Node** n) {
Status status;
*n = graph->AddNode(nodedef_, &status);
return status;
}
IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Identity(
const string& name, const DataType& type, const NodeDebugInfo& debug) {
return IncompleteNodeDefBuilder(name, "Identity", debug).AddAttr("T", type);
}
IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Merge(
const string& name, const DataType& type, const NodeDebugInfo& debug,
int n) {
return IncompleteNodeDefBuilder(name, "Merge", debug)
.AddAttr("T", type)
.AddAttr("N", n);
}
IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Switch(
const string& name, const DataType& type, const NodeDebugInfo& debug) {
return IncompleteNodeDefBuilder(name, "Switch", debug).AddAttr("T", type);
}
} // namespace tensorflow

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_
#include <string>
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Convenience builder for constructing NodeDefs without specifying their
// inputs; otherwise it behaves like NodeDefBuilder.
// TODO(jpienaar): Clean up NodeDefBuilder and remove this class.
class IncompleteNodeDefBuilder {
public:
IncompleteNodeDefBuilder(const string& name, const string& op,
const NodeDebugInfo& debug);
IncompleteNodeDefBuilder& AddAttr(const string& attr, const DataType& type);
IncompleteNodeDefBuilder& AddAttr(const string& attr, int val);
IncompleteNodeDefBuilder& Device(const string& device);
Status Build(Graph* graph, Node** n);
static IncompleteNodeDefBuilder Identity(const string& name,
const DataType& type,
const NodeDebugInfo& debug);
static IncompleteNodeDefBuilder Merge(const string& name,
const DataType& type,
const NodeDebugInfo& debug, int n);
static IncompleteNodeDefBuilder Switch(const string& name,
const DataType& type,
const NodeDebugInfo& debug);
private:
NodeDef nodedef_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_
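
A minimal usage sketch of the builder, assuming `graph` is a valid Graph* and `src_node` is an existing Node* that supplies debug info and a device assignment (both names are illustrative, not from the commit):

// Sketch only: builds an Identity node; its input edge is added separately,
// which is the point of the "incomplete" builder.
Node* identity_node = nullptr;
TF_RETURN_IF_ERROR(
    IncompleteNodeDefBuilder::Identity(graph->NewName("rewrite/identity"),
                                       DT_FLOAT, NodeDebugInfo(*src_node))
        .Device(src_node->assigned_device_name())
        .Build(graph, &identity_node));
graph->AddEdge(src_node, 0, identity_node, 0);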

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"
@ -30,8 +31,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34,
EncapsulateTPUComputationsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39,
ExtractOutsideCompilationPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 40,
DistributedTPURewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
VariableMergerPass);
} // namespace
} // namespace tensorflow

View File

@ -55,8 +55,8 @@ class MultiPlatformManagerImpl {
TF_LOCKS_EXCLUDED(mu_);
port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter)
TF_LOCKS_EXCLUDED(mu_);
const std::function<bool(const Platform*)>& filter,
bool initialize_platform) TF_LOCKS_EXCLUDED(mu_);
using Listener = MultiPlatformManager::Listener;
port::Status RegisterListener(std::unique_ptr<Listener> listener)
@ -188,7 +188,8 @@ port::Status MultiPlatformManagerImpl::RegisterListener(
port::StatusOr<std::vector<Platform*>>
MultiPlatformManagerImpl::PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter) {
const std::function<bool(const Platform*)>& filter,
bool initialize_platform) {
absl::MutexLock lock(&mu_);
CHECK_EQ(id_map_.size(), name_map_.size());
std::vector<Platform*> platforms;
@ -196,7 +197,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
for (const auto& entry : id_map_) {
Platform* platform = entry.second;
if (filter(platform)) {
if (!platform->Initialized()) {
if (initialize_platform && !platform->Initialized()) {
SE_RETURN_IF_ERROR(platform->Initialize({}));
}
platforms.push_back(platform);
@ -299,7 +300,14 @@ MultiPlatformManager::InitializePlatformWithId(
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter) {
return Impl().PlatformsWithFilter(filter);
return PlatformsWithFilter(filter, /*initialize_platform=*/true);
}
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter,
bool initialize_platform) {
return Impl().PlatformsWithFilter(filter, initialize_platform);
}
} // namespace stream_executor

View File

@ -130,6 +130,10 @@ class MultiPlatformManager {
static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter);
static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
const std::function<bool(const Platform*)>& filter,
bool initialize_platform);
// Although the MultiPlatformManager "owns" its platforms, it holds them as
// undecorated pointers to prevent races during program exit between this
// object's data and the underlying platforms (e.g., CUDA, OpenCL).
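
A hedged sketch of the new overload in use; the name-based filter is illustrative only:

// Sketch only: enumerate registered TPU platforms without forcing their
// initialization.
auto platforms_or = stream_executor::MultiPlatformManager::PlatformsWithFilter(
    [](const stream_executor::Platform* platform) {
      return platform->Name() == "TPU";
    },
    /*initialize_platform=*/false);
if (platforms_or.ok()) {
  for (stream_executor::Platform* platform : platforms_or.ValueOrDie()) {
    LOG(INFO) << "Found TPU platform: " << platform->Name();
  }
}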

View File

@ -331,6 +331,7 @@ cc_library(
name = "tpu_topology_external",
srcs = ["tpu_topology.cc"],
hdrs = ["tpu_topology.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_decl",
"//tensorflow/core/platform:types",

View File

@ -23,10 +23,11 @@ namespace tensorflow {
namespace tpu {
namespace {
TpuPlatformInterface* GetRegisteredPlatformStatic() {
TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform) {
// Prefer TpuPlatform if it's registered.
auto status_or_tpu_platform =
stream_executor::MultiPlatformManager::PlatformWithName("TPU");
stream_executor::MultiPlatformManager::PlatformWithName(
"TPU", initialize_platform);
if (status_or_tpu_platform.ok()) {
return static_cast<TpuPlatformInterface*>(
status_or_tpu_platform.ValueOrDie());
@ -43,7 +44,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {
[](const stream_executor::Platform* platform) {
return dynamic_cast<const TpuPlatformInterface*>(platform) !=
nullptr;
});
},
initialize_platform);
if (!status_or_other_tpu_platforms.ok()) {
LOG(WARNING) << "Error when getting other TPU platforms: "
<< status_or_tpu_platform.status();
@ -64,9 +66,24 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
// Use a local static variable to avoid data races during initialization.
return GetRegisteredPlatform(/*initialize_platform=*/true);
}
/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
bool initialize_platform) {
static bool requested_initialize_platform = initialize_platform;
static TpuPlatformInterface* tpu_registered_platform =
GetRegisteredPlatformStatic();
GetRegisteredPlatformStatic(initialize_platform);
if (!requested_initialize_platform && initialize_platform) {
// If platform initialization was not requested the first time this
// function was called, but a later caller wants the platform initialized,
// call GetRegisteredPlatformStatic again so that the platform gets
// initialized this time.
tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
}
return tpu_registered_platform;
}

View File

@ -33,6 +33,9 @@ class TpuPlatformInterface : public stream_executor::Platform {
// is registered or an error occurred.
static TpuPlatformInterface* GetRegisteredPlatform();
// Like the above, but optionally skips initializing the platform when
// initialization is not necessary.
static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform);
virtual Status Reset() { return Reset(false); }
virtual Status Reset(bool only_tear_down) = 0;
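
A hedged sketch of the new overload; useful, for example, to check whether a TPU platform is registered without paying the initialization cost:

// Sketch only: returns the registered platform, or nullptr, without
// initializing it.
TpuPlatformInterface* tpu_platform =
    TpuPlatformInterface::GetRegisteredPlatform(/*initialize_platform=*/false);
if (tpu_platform == nullptr) {
  LOG(INFO) << "No TPU platform registered yet.";
}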

View File

@ -30,6 +30,7 @@ struct TpuChipCoordinatesExternal {
class TpuCoreLocationExternal {
public:
TpuCoreLocationExternal() : core_location_(nullptr) {}
explicit TpuCoreLocationExternal(void* core_location)
: core_location_(core_location) {}
TpuChipCoordinatesExternal chip_coordinates() const;