Open source distributed_tpu_rewrite_pass.cc and associated helper methods
PiperOrigin-RevId: 322460893
Change-Id: I8ca6164e8c4ce2b6d6e79db66fbb028305634ca5
@@ -13,6 +13,7 @@ cc_library(
    srcs = ["tpu_rewrite_pass_registration.cc"],
    deps = [
        ":distributed_tpu_configuration_rewrite_pass",
        ":distributed_tpu_rewrite_pass",
        ":encapsulate_tpu_computations_pass",
        ":variable_merger_pass",
        "//tensorflow/core:core_cpu",
@@ -101,3 +102,120 @@ cc_library(
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "distributed_tpu_rewrite_pass_internal",
    srcs = ["distributed_tpu_rewrite_pass_internal.cc"],
    hdrs = ["distributed_tpu_rewrite_pass_internal.h"],
    deps = [
        "//tensorflow/core:framework",
        "@com_google_absl//absl/random",
    ],
)

cc_library(
    name = "distributed_tpu_rewrite_pass",
    srcs = [
        "distributed_tpu_rewrite_pass.cc",
    ],
    hdrs = [
        "distributed_tpu_rewrite_pass.h",
    ],
    deps = [
        ":cond_builder",
        ":distributed_tpu_rewrite_helpers",
        ":distributed_tpu_rewrite_pass_internal",
        ":host_training_loop_optimization_util",
        ":incomplete_nodedef_builder",
        "//tensorflow/compiler/jit:encapsulate_util",
        "//tensorflow/compiler/jit:shape_inference",
        "//tensorflow/compiler/tf2xla:resource_operation_table",
        "//tensorflow/compiler/tf2xla:sharding_util",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:array4d",
        "//tensorflow/compiler/xla:xla_proto_cc",
        "//tensorflow/compiler/xla/client:sharding_builder",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/core:framework",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:session_options",
        "//tensorflow/core/common_runtime:function",
        "//tensorflow/core/common_runtime:graph_constructor",
        "//tensorflow/core/common_runtime:lower_function_call_op",
        "//tensorflow/core/common_runtime:lower_functional_ops",
        "//tensorflow/core/common_runtime:lower_if_op",
        "//tensorflow/core/common_runtime:lower_while_op",
        "//tensorflow/core/common_runtime:optimization_registry",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
        "//tensorflow/core/protobuf/tpu:topology_proto_cc",
        "//tensorflow/core/tpu:tpu_compile_interface",
        "//tensorflow/core/tpu:tpu_defs",
        "//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
        "//tensorflow/stream_executor/tpu:tpu_platform_interface",
        "//tensorflow/stream_executor/tpu:tpu_topology_external",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:node_hash_map",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "incomplete_nodedef_builder",
    srcs = ["incomplete_nodedef_builder.cc"],
    hdrs = ["incomplete_nodedef_builder.h"],
    deps = [
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "cond_builder",
    srcs = ["cond_builder.cc"],
    hdrs = ["cond_builder.h"],
    deps = [
        ":incomplete_nodedef_builder",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "host_training_loop_optimization_util",
    srcs = [
        "host_training_loop_optimization_util.cc",
    ],
    hdrs = [
        "host_training_loop_optimization_util.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":distributed_tpu_rewrite_pass_internal",
        "//tensorflow/compiler/tf2xla:functionalize_control_flow_util",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:node_hash_set",
        "@com_google_absl//absl/types:optional",
    ],
)
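The first hunk above adds ":distributed_tpu_rewrite_pass" as a dependency of the registration library built from tpu_rewrite_pass_registration.cc. That file is not shown in this diff; as a hedged sketch only, registering a GraphOptimizationPass conventionally looks like the following (the grouping and phase number here are illustrative assumptions, not the values used by the actual registration file):

// Hedged sketch; the real registration lives in
// tpu_rewrite_pass_registration.cc, which this diff does not show.
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

namespace tensorflow {
// PRE_PLACEMENT is consistent with the pass's own note that it must run
// before placement; the phase value 20 is an assumption.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 20,
                      DistributedTPURewritePass);
}  // namespace tensorflow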
tensorflow/core/tpu/graph_rewrite/cond_builder.cc (new file, 83 lines)
@@ -0,0 +1,83 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/cond_builder.h"

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"

namespace tensorflow {

CondBuilder::CondBuilder(string name, string device, const NodeDebugInfo& debug,
                         Graph* graph)
    : graph_(graph), name_(std::move(name)), device_(std::move(device)) {
  auto new_name = [graph, this](string suffix) {
    return graph->NewName(strings::StrCat(name_, "/", suffix));
  };
  TF_CHECK_OK(
      IncompleteNodeDefBuilder::Identity(new_name("pred"), DT_BOOL, debug)
          .Device(device_)
          .Build(graph_, &pred_));
  Node* switch_pred;
  TF_CHECK_OK(
      IncompleteNodeDefBuilder::Switch(new_name("switch_pred"), DT_BOOL, debug)
          .Device(device_)
          .Build(graph_, &switch_pred));
  graph_->AddEdge(pred(), 0, switch_pred, 0);
  graph_->AddEdge(pred(), 0, switch_pred, 1);
  TF_CHECK_OK(
      IncompleteNodeDefBuilder::Identity(new_name("switch_f"), DT_BOOL, debug)
          .Device(device_)
          .Build(graph_, &switch_f_));
  TF_CHECK_OK(
      IncompleteNodeDefBuilder::Identity(new_name("switch_t"), DT_BOOL, debug)
          .Device(device_)
          .Build(graph_, &switch_t_));
  graph_->AddEdge(switch_pred, kElseBranch, switch_f_, 0);
  graph_->AddEdge(switch_pred, kThenBranch, switch_t_, 0);
  Node* merge_pred;
  TF_CHECK_OK(IncompleteNodeDefBuilder::Merge(new_name("merge_pred"), DT_BOOL,
                                              debug, /*n=*/2)
                  .Device(device_)
                  .Build(graph_, &merge_pred));
  graph_->AddEdge(switch_f_, 0, merge_pred, kElseBranch);
  graph_->AddEdge(switch_t_, 0, merge_pred, kThenBranch);
  // Note: when additional return values are added there should be a
  // control dependency between those merge nodes and control_successor_ to
  // ensure that it is a control successor of the conditional.
  control_successor_ = merge_pred;
}

Node* CondBuilder::pred() { return pred_; }

Node* CondBuilder::switch_f() { return switch_f_; }

Node* CondBuilder::switch_t() { return switch_t_; }

Node* CondBuilder::control_successor() { return control_successor_; }

Status CondBuilder::AddInput(const string& input_name, const DataType& type,
                             const string& device, const NodeDebugInfo& debug,
                             Node** input) {
  auto b = IncompleteNodeDefBuilder::Switch(
      graph_->NewName(strings::StrCat(name_, "/", input_name)), type, debug);
  TF_RETURN_IF_ERROR(b.Device(device).Build(graph_, input));
  graph_->AddEdge(pred(), 0, *input, 1);
  return Status::OK();
}

}  // namespace tensorflow
tensorflow/core/tpu/graph_rewrite/cond_builder.h (new file, 74 lines)
@@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_

#include <string>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Conditional builder.
// Convenience builder to make it easy to construct a conditional. E.g.,
//   Node* pred = ...;
//   CondBuilder cb("cond", g);
//   auto switch_var = cb.AddInput("var", DT_RESOURCE);
//   g->AddEdge(pred, 0, cb.pred(), 0);
// Will create the nodes of a conditional that takes a resource
// variable ("var") as input and that switches on pred.
//
// This currently only handles the case needed by distributed_tpu_rewrite_pass
// and is not completely general.
class CondBuilder {
 public:
  enum Branch { kElseBranch = 0, kThenBranch = 1 };

  CondBuilder(string name, string device, const NodeDebugInfo& debug,
              Graph* graph);

  // Returns node corresponding to the predicate input.
  Node* pred();

  // Returns node corresponding to switch_f branch of predicate switch.
  Node* switch_f();

  // Returns node corresponding to switch_t branch of predicate switch.
  Node* switch_t();

  // Returns node corresponding to control successor.
  Node* control_successor();

  // Returns the Switch node to feed a value of the given type into the
  // conditional.
  Status AddInput(const string& input_name, const DataType& type,
                  const string& device, const NodeDebugInfo& debug,
                  Node** input);

 private:
  Node* control_successor_;
  Node* switch_f_;
  Node* switch_t_;
  Node* pred_;
  Graph* const graph_;
  const string name_;
  const string device_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_
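A minimal usage sketch (not part of this commit; `graph`, `pred_node`, and `var_node` are hypothetical, and the CPU device string is an assumption) showing how the CondBuilder API above could be wired up:

// Sketch only: builds a conditional that routes a resource variable through
// a Switch guarded by an externally supplied predicate.
Status BuildGuardedSwitch(Graph* graph, Node* pred_node, Node* var_node) {
  NodeDebugInfo debug(pred_node->def());  // hypothetical debug-info source
  CondBuilder cb("guarded", "/device:CPU:0", debug, graph);
  // Feed the external predicate into the conditional's predicate input.
  graph->AddEdge(pred_node, 0, cb.pred(), 0);
  // Obtain a Switch for the variable and connect the variable to it.
  Node* switch_var;
  TF_RETURN_IF_ERROR(
      cb.AddInput("var", DT_RESOURCE, "/device:CPU:0", debug, &switch_var));
  graph->AddEdge(var_node, 0, switch_var, 0);
  return Status::OK();
}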
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc (new file, 4105 lines)
File diff suppressed because it is too large.
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h (new file, 589 lines)
@@ -0,0 +1,589 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Rewrites TPUReplicate nodes into replicated computations on TPU.
//
// To represent a distributed TPU computation, we use the
// TPUReplicate operator, which describes a subgraph (represented as a
// Tensorflow function) to replicate across a TPU pod.
//
// Model parallelism and data parallelism:
// ---------------------------------------
// We support two different kinds of parallelism on TPU:
// * data parallelism (replication), or parallelization across batches, and
// * model parallelism, or parallelization within a batch.
//
// The function passed to a TPUReplicate operator is replicated many
// times across a TPU pod (data parallelism). The `num_replicas` attribute
// controls how many replicas of the computation to create. Replicas are mostly
// independent; replicas can only communicate using the CrossReplicaSum
// operator, which is typically used to communicate gradients during training.
//
// Each replica may optionally use more than one TPU core (model
// parallelism). The `num_cores_per_replica` attribute controls how many cores
// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE
// device that is only valid within replicated TPU computations (e.g.,
// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE
// device corresponds to one TPU core in every replica.
// Each replica runs its own copy of the computation assigned to each
// TPU_REPLICATED_CORE device.
//
// The Python code is responsible for providing a device_assignment that
// describes how the replicated logical cores map to physical cores on the TPU
// topology.
//
// Inputs to TPUReplicate:
// ------------------------------
// The TPUReplicate operator takes four kinds of inputs, in the
// following order:
// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two
//   replicas, the first six arguments to TPUReplicate will be:
//     A0 B0 C0 A1 B1 C1
//   where Ai is the A input to the i-th replica.
// * distributed inputs. These inputs follow the per-replica inputs.
//   If there are two distributed inputs (E, F) and two replicas, the next
//   arguments to TPUReplicate will be: E F.
//   Each replica has its own local E and F.
// * broadcast inputs. These inputs follow the distributed inputs. All
//   replicas receive a copy of each of these inputs.
// * variables. Resource variables accessed by the computation follow the
//   broadcast inputs.
//
// For example, for a computation with two replicas, three per-replica inputs
// (A, B, C), two distributed inputs (E, F), two broadcast inputs (X, Y), and
// two variables (V, W), the arguments to TPUReplicate will be:
//   A0 B0 C0 A1 B1 C1 E F X Y V W
// and each replica will receive the following arguments:
//   A B C E F X Y V W
//
// Distributed TPU compilation requires that the shapes of all operators
// be known statically at compilation time, before any nodes have executed.
// Shapes are determined using shape information emitted by InferShapes. It
// is not possible to replicate Tensorflow operators with unknown or dynamic
// shapes for TPU at present.
//
// Graph rewrite:
// --------------
// Compilation replaces TPUReplicate operators with:
// * a single TPUCompile node that compiles the computations,
// * one TPUExecute node for each TPU device in the system that
//   executes the relevant computation,
// * one ReadVariableOp for each variable accessed by the replicated
//   computation,
// * one AssignVariableOp for each variable accessed by the replicated
//   computation. An assignment is built even if a variable is only read by the
//   computation. We do not know which variables are written until we apply the
//   XlaCompiler to the computation, but that does not happen until after the
//   rewrite. Conservatively, we write back the values of all variables after
//   the computation completes.
//   TODO(phawkins): only write back variables that the computation may write.
// * one Shape node for each Tensor or Variable input to the computation whose
//   shape is not statically known at rewrite time. The input shapes are fed
//   to the TPUCompile node.
//
// To ensure that the reads and writes seem to happen at the right time in the
// graph execution, we add control edges from all predecessors of the original
// TPUReplicate operator to each of the ReadVariableOp operators.
// Similarly, we add control edges from all of the AssignVariableOp operators to
// all of the successors of the TPUReplicate operator.
//
// The TPUReplicate rewrite must run before placement, since resource
// variable inputs will have DT_RESOURCE, which cannot be sent across devices,
// leading to objections from the placer. The rewrite rewrites the resource
// accesses into explicit ReadVariableOp and AssignVariableOp operators that the
// placer is free to colocate with the variables.

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_

#include <string>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/shape_inference.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/stream_executor/tpu/tpu_topology.h"

namespace tensorflow {

// Replaces clusters assigned to TPU_SYSTEM devices with
// TPUCompile and TPUExecute nodes assigned to the corresponding
// TPU devices.
class DistributedTPURewritePass : public GraphOptimizationPass {
 public:
  static void SetDistributedTpuRewritePassOptions(
      bool distribute_vars,
      bool replicate_inputs_outputs_by_default_for_xla_spmd,
      bool enable_cross_replica_sharding_mirrored_variables,
      bool enable_automatic_model_parallelism);

  Status Run(const GraphOptimizationPassOptions& options) override;

  // The following methods are public only for the use of unit tests.

  // See comment at the top of the file for how the inputs are ordered.
  // Encapsulates the different TPU replicated node input and output
  // information, and provides common APIs over them.
  class ParameterInfo {
   public:
    ParameterInfo() {}
    ParameterInfo(int64 num_replicas, int64 num_per_replica_args,
                  int64 num_distributed_args, int64 num_broadcast_args,
                  int64 num_variables, int64 num_guaranteed_constants,
                  int64 num_retvals_per_replica)
        : num_replicas_(num_replicas),
          num_per_replica_args_(num_per_replica_args),
          num_distributed_args_(num_distributed_args),
          num_broadcast_args_(num_broadcast_args),
          num_variables_(num_variables),
          num_guaranteed_constants_(num_guaranteed_constants),
          num_retvals_per_replica_(num_retvals_per_replica) {}

    int64 NumReplicas() const { return num_replicas_; }

    int64 NumPerReplicaArgs() const { return num_per_replica_args_; }

    int64 NumDistributedArgs() const { return num_distributed_args_; }

    int64 NumBroadcastArgs() const { return num_broadcast_args_; }

    int64 NumVariables() const { return num_variables_; }

    int64 NumGuaranteedConstants() const { return num_guaranteed_constants_; }

    int64 NumRetvalsPerReplica() const { return num_retvals_per_replica_; }

    bool IsPerReplicaArg(int64 index) const {
      return index < num_per_replica_args_;
    }

    bool IsDistributedArg(int64 index) const {
      return index >= num_per_replica_args_ &&
             index < (num_per_replica_args_ + num_distributed_args_);
    }

    bool IsBroadcastArg(int64 index) const {
      return index >= num_per_replica_args_ &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_);
    }

    bool IsVariableArg(int64 index) const {
      return index >= (num_per_replica_args_ + num_broadcast_args_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_);
    }

    bool IsConstantArg(int64 index) const {
      return index >= (num_per_replica_args_ + num_distributed_args_ +
                       num_broadcast_args_ + num_variables_) &&
             index < (num_per_replica_args_ + num_distributed_args_ +
                      num_broadcast_args_ + num_variables_ +
                      num_guaranteed_constants_);
    }

    // Returns the number of inputs received by the host.
    int64 NumInputsFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the number of inputs which will be sent to each replica.
    int64 NumInputsToEachReplica() const {
      return num_per_replica_args_ + num_distributed_args_ +
             num_broadcast_args_ + num_variables_ + num_guaranteed_constants_;
    }

    // Returns the total number of output values returned to the host (for all
    // replicas).
    int64 NumOutputsToHost() const {
      return num_replicas_ * num_retvals_per_replica_;
    }

    // Returns the position of the first broadcast argument within the set of
    // all host arguments; broadcast arguments follow the distributed
    // arguments.
    int64 FirstBroadcastArgFromHost() const {
      return num_replicas_ * num_per_replica_args_ + num_distributed_args_;
    }

    // Indices of mirrored variables across replicas, which should be
    // categorized as per_replica_args.
    const std::set<int64>& mirrored_variable_indices() const {
      return mirrored_variable_indices_;
    }
    std::set<int64>* mutable_mirrored_variable_indices() {
      return &mirrored_variable_indices_;
    }

   private:
    int64 num_replicas_ = 1;
    int64 num_per_replica_args_ = 0;
    int64 num_distributed_args_ = 0;
    int64 num_broadcast_args_ = 0;
    int64 num_variables_ = 0;
    int64 num_guaranteed_constants_ = 0;
    int64 num_retvals_per_replica_ = 0;
    std::set<int64> mirrored_variable_indices_;
  };

  // Mapping from TPUReplicate cluster name to tpu device names. Value is a
  // mapping from [replica][core] to a TF device name.
  typedef absl::flat_hash_map<string, std::vector<std::vector<string>>>
      TPUReplicateDeviceNamesMapping;

  // Determines which devices to use to run the computation.
  // Inputs:
  // * num_tpus_per_task: the number of TPU devices attached to each task
  // * tpu_devices: a [task][device] collection of TPU devices
  // * num_replicas: the number of replicas requested
  // * num_cores_per_replica: the number of cores in each computation instance
  // * topology_attr: the topology TPUReplicate attribute
  // * device_assignment_attr: the device_assignment TPUReplicate attribute
  // Outputs:
  // * tf_device_assignment: a mapping from [replica][core] to a TF device name
  // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU
  //   coordinate.
  // TODO(phawkins): change tf_device_assignment to an xla::Array2D.
  static Status BuildDeviceAssignment(
      const tpu::TpuTopologyExternal& topology, int num_tpus_per_task,
      const std::vector<std::vector<Device*>>& tpu_devices, int num_replicas,
      int num_cores_per_replica, const string& topology_attr,
      absl::Span<const int> device_assignment_attr,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment);

  // Returns the `computation` graph attached to TPUReplicate operator
  // `node`. `flr` is a FunctionLibraryRuntime to use when
  // instantiating the function body. Sets `*arg_types` and
  // `*retval_types` to the argument/return types of the function.
  static Status GetComputationForTPUReplicateOp(const NameAttrList& function,
                                                FunctionLibraryRuntime* flr,
                                                Graph* computation,
                                                DataTypeVector* arg_types,
                                                DataTypeVector* retval_types);

  // Returns the shapes of the argument tensors and return values of the
  // TPUReplicate operator `node` using the _output_shapes,
  // _output_handle_shapes, and _output_handle_types annotations on the input
  // nodes. Expects inputs in the following order (see comment at top of file):
  // * num_replicas * num_per_replica_args per-replica inputs,
  // * num_broadcast_args broadcast inputs,
  // * num_variables variable inputs.
  // Returns an error if the input shapes to `node` are not statically known.
  // Also verifies that all replicas have identical input shapes for their
  // per-replica inputs.
  static Status GetArgAndRetvalShapes(
      const GraphShapeInfo& shape_info, const Node& node,
      const ParameterInfo& params_info, std::vector<InferredShape>* arg_shapes,
      std::vector<InferredShape>* retval_shapes);

  // Assigns arguments and return values to cores. The assignment is
  // represented as an XLA op sharding, so that an argument can be replicated
  // across cores.
  // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by
  // argument/retval number.
  // `arg_fast_mem` is a vector of fast_mem indications, indexed by
  // argument number.
  static Status AssignArgsAndRetvalsToCores(
      int num_cores_per_replica, const ParameterInfo& params_info,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<InferredShape>& retval_shapes, const Graph& graph,
      const Node* replicate_node, FunctionLibraryRuntime* flr,
      std::vector<::xla::OpSharding>* arg_sharding,
      std::vector<bool>* arg_fast_mem,
      std::vector<::xla::OpSharding>* retval_sharding);

  // Computes a fingerprint of the contents of `library`.
  static Status FingerprintFunctionLibrary(
      const FunctionLibraryDefinition& library, uint64* fingerprint);

  // Populates `*variables` with the "variables" inputs to the `index`-th
  // output of `node`.
  struct VariableInput {
    Node* node;
    int index;

    // Type of the variable's value. Note that this is different to the type of
    // the output of 'variable', which is always DT_RESOURCE.
    DataType dtype;
  };
  static Status FindVariableInputs(const Node& node,
                                   const NameRangeMap& input_range_map,
                                   std::vector<VariableInput>* variables);

  // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs
  // to 'node'.
  static Status FindGuaranteedConstantInputs(
      const Node& node, const NameRangeMap& input_range_map,
      std::vector<Node*>* guaranteed_constants);

  // Builds Shape nodes that compute the shapes of arguments whose shapes are
  // not statically known.
  static Status BuildDynamicShapeNodes(
      const Node& replicate_node, const std::vector<InferredShape>& arg_shapes,
      const ParameterInfo& params_info,
      const std::vector<Node*>& variable_reads, Graph* graph,
      std::vector<Node*>* dynamic_shape_nodes);

  // Builds a TPUCompile node that compiles the computation in
  // `function_names`, which the TPUExecute `nodes` invoke.
  // TODO(b/33943292): at present, for model parallelism with Send/Recv to work
  // the `nodes` must correspond to the computations assigned to TPU:0,
  // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated
  // executables.
  static Status BuildCompileNode(
      const Node* replicate_node, const NameAttrList& function,
      uint64 library_fingerprint, const ParameterInfo& params_info,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& arg_types,
      const std::vector<Node*>& guaranteed_constant_nodes,
      const string& session_handle,
      const std::vector<::xla::OpSharding>& arg_sharding,
      const std::vector<bool>& arg_fast_mem,
      const std::vector<::xla::OpSharding>& retval_sharding,
      int num_cores_per_replica, const string& compile_device,
      const xla::DeviceAssignment* xla_device_assignment,
      const std::vector<Node*>& dynamic_shape_nodes, Graph* graph,
      Node** compile_node, int64 autotuner_thresh);

  // Builds a TPUCompileSucceededAssert node that verifies that compilation
  // succeeded and replaces the TPUCompilationStatus node in the graph.
  static Status BuildCompilationStatusReturnNodes(
      Node* replicate_node, Node* compile_node,
      Node** control_after_compilation, Graph* graph);

  // Builds ReadVariableOp nodes that read `variables`, with control
  // edges that ensure they happen after `control_predecessor`.
  static Status BuildVariableReads(absl::Span<const VariableInput> variables,
                                   Node* control_predecessor, Graph* graph,
                                   std::vector<Node*>* variable_reads);

  // Returns true if the graph or functions contain a resource write op,
  // otherwise returns false.
  // TODO(b/137048563): Recognize unused resource rewrite op.
  static bool ContainsResourceWriteOp(const Graph& graph,
                                      const FunctionLibraryDefinition& fld);
  // Struct that describes a variable value to be written back from TPUExecute.
  struct VariableWrite {
    // A node:output pair containing a boolean tensor that determines whether
    // the value should be written back.
    Node* predicate;
    int predicate_output;

    // A node:output pair containing the value to be written back.
    Node* value;
    int value_output;
  };

  // Builds AssignVariableOp nodes that write `variables` with the values from
  // `variable_writes`, with control edges that ensure the writes happen before
  // `control_successor`.
  static Status BuildVariableWrites(
      absl::Span<const VariableInput> variables, Node* control_successor,
      absl::Span<const VariableWrite> variable_writes, Graph* graph);

  // Builds TPUExecute operators assigned to each TPU device
  // involved in the computation.
  // Arguments:
  // * `params_info` is the structure containing the information about the
  //   TPUReplicate node inputs and outputs.
  // * `num_tasks` is the number of TensorFlow tasks in the slice.
  // * `num_cores_per_replica` is the number of cores which are dedicated to
  //   each replica.
  // * `replicate_node` is the original TPUReplicate node.
  // * `arg_types` are the types of the arguments to the computation function
  //   passed as argument to TPUReplicate, including per-replica,
  //   broadcast, and variable arguments.
  // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if
  //   applicable).
  // * `arg_shardings` and `retval_shardings` are mappings from
  //   arguments/return indices to shardings, as returned by
  //   `AssignArgsAndRetvalsToCores`.
  // * `pod_devices` lists the devices to assign to each core of each replica.
  // * `variable_reads` is a vector of ReadVariableOp operators, one for each
  //   variable argument to the computation.
  // * The execute operators will have a control edge from
  //   `control_predecessor` and another control edge to `control_successor`.
  // Populates '*variable_writes' with information about variable values to
  // write back.
  static Status BuildExecuteNodes(
      const ParameterInfo& params_info, int num_tasks,
      int num_cores_per_replica, const Node& replicate_node,
      const DataTypeVector& arg_types,
      const std::vector<InferredShape>& arg_shapes,
      const DataTypeVector& retval_types,
      const std::vector<::xla::OpSharding>& arg_shardings,
      const std::vector<::xla::OpSharding>& retval_shardings,
      const std::vector<std::vector<string>>& tpu_device_names,
      Node* compile_node, const std::vector<Node*>& variable_reads,
      Node* control_predecessor, Node* control_successor,
      std::vector<VariableWrite>* variable_writes, Graph* graph);

  // Connects the compile node to all the host transfer nodes, and removes the
  // key placeholder node that was previously standing in for it.
  // Arguments:
  // * `compile_node` is the TPUCompile node that has been added to the graph.
  // * `key_placeholder_node` is the placeholder node used to send the key to
  //   all the host transfer nodes in the original graph.
  // * `graph` is the graph being rewritten.
  static Status ConnectHostComputeNodes(Node* compile_node,
                                        Node* key_placeholder_node,
                                        Graph* graph);

  // Map from a Node in an outside_compilation cluster in the original graph to
  // the list of Nodes, one for each replica, that it is expanded into during
  // replication.
  typedef absl::node_hash_map<Node*, std::vector<Node*>> NodeToNodeReplicasMap;

  // Map from the name of an outside_compilation cluster to the model-parallel
  // core index that the HostCompute Op should be placed on in that cluster.
  typedef std::map<string, int> HostComputeCoreMap;

  // Map from the name of an outside_compilation cluster to the list of Nodes
  // that should run on the host for that cluster.
  typedef std::map<string, std::vector<Node*>> OutsideCompilationNodeMap;

  // Copies the outside_compilation nodes in a cluster to create replica
  // replica_index.
  static Status CopyOutsideCompilationNodes(
      int replica_index, const std::vector<Node*>& outside_compilation_nodes,
      const DeviceNameUtils::ParsedName& tpu_device,
      const DeviceNameUtils::ParsedName& partial_device,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Replicates all the nodes in outside_compilation clusters in a compiled
  // computation.
  static Status ReplicateOutsideCompilationNodes(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const HostComputeCoreMap& host_compute_core,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      NodeToNodeReplicasMap* node_images, Graph* graph);

  // Lifts the edges between original outside_compilation nodes in a cluster
  // onto their replicas.
  static Status CopyOutsideCompilationEdges(
      const std::vector<Node*>& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Lifts all the edges in outside_compilation clusters in a compiled
  // computation to their replicas.
  static Status ReplicateOutsideCompilationEdges(
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const NodeToNodeReplicasMap& node_images,
      const std::unordered_map<string, Node*> outside_compilation_inputs,
      Graph* graph);

  // Removes all the original outside_compilation nodes from the graph,
  // following replication.
  static Status RemoveOutsideCompilationNodes(
      const NodeToNodeReplicasMap& node_images, Graph* graph);

  // Lowers outside compilation functional nodes (If/While/function call).
  // Otherwise, when we have multiple workers, the device placer will not be
  // able to place nodes if outside compilation has DT_RESOURCE inputs (e.g. a
  // DT_RESOURCE input fed into multiple While nodes on different devices).
  static Status LowerOutsideCompilationFunctionalNodes(
      Graph* g, const FunctionLibraryDefinition& flib_def,
      const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping);

  // Parses the 'host_compute_core' attribute on replicate_node to get the
  // replicated core id of each outside_compilation cluster.
  static Status ParseHostComputeCores(
      const Node& replicate_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      HostComputeCoreMap* host_compute_core);

  // Gets the physical topology information about the TPU system.
  static Status GetDeviceTopology(
      const DeviceSet& device_set, const Node& replicate_node,
      int* num_replicas, int* num_cores_per_replica, int* num_tasks,
      std::vector<std::vector<string>>* tf_device_assignment,
      std::unique_ptr<xla::DeviceAssignment>* xla_device_assignment,
      string* tpu_compilation_device);

  // Gets the types of args, retvals, and parameters.
  static Status GetIOTypes(
      int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr,
      Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function,
      std::unique_ptr<Graph>* computation, DataTypeVector* arg_types,
      DataTypeVector* retval_types, ParameterInfo* params_info);

  // Finds known constants and deals with variable reads.
  static Status DealWithConstantsAndVariables(
      const Node& replicate_node, const NameRangeMap& input_name_map,
      Graph* graph, Node* host_transfer_sequencer, Node* control_before,
      Node* control_after, absl::Span<const VariableInput> variable_nodes,
      std::vector<Node*>* guaranteed_constant_nodes,
      std::vector<Node*>* variable_reads);

  // Adds NoOp nodes for sequencing computation and variable reads/writes.
  static Status BuildSequencingNodes(const string& tpu_compilation_device,
                                     const Node& replicate_node, Graph* graph,
                                     Node** host_transfer_sequencer,
                                     Node** control_before,
                                     Node** control_after);

  // Performs the pass's rewrite on a TPUReplicate node `node`.
  static Status RewriteTPUReplicateNode(
      const string& session_handle, const DeviceSet& device_set,
      Node* replicate_node, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node,
      const OutsideCompilationNodeMap& outside_compilation_nodes,
      const std::vector<Node*>& head_tail_outside_compilation_nodes,
      NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph,
      const GraphShapeInfo& shape_info,
      TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping,
      int64 autotuner_thresh);

  // Performs host training loop optimization. For example, when the TPUExecute
  // node is inside a while loop, the model weight variables can be sharded in
  // the XLA-preferred layout and then unsharded only at the very last
  // iteration, reducing the number of all-gather operations.
  static Status PerformHostTrainingLoopOptimization(
      Graph* graph, FunctionLibraryDefinition* flib_def,
      FunctionLibraryRuntime* flr);

  // Heuristically place some nodes with unassigned devices on TPUs for
  // performance reasons.
  static Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph);

  // Updates the head and tail outside compiled nodes so that nodes have the
  // correct device and removes the replication and outside compilation
  // attributes so that these nodes do not trigger further graph optimization
  // passes.
  static Status UpdateHeadTailOutsideCompilation(
      const std::vector<std::vector<string>>& tf_device_assignment,
      const std::vector<Node*>& head_tail_outside_compilation_nodes);

 private:
  static bool distribute_vars_;
  static bool replicate_inputs_outputs_by_default_for_xla_spmd_;
  static bool enable_cross_replica_sharding_mirrored_variables_;
  static bool enable_automatic_model_parallelism_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_
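To make ParameterInfo's index bookkeeping concrete, here is a hedged illustration (not part of the commit) that expresses the worked example from the header comment above, 2 replicas with per-replica inputs (A, B, C), distributed inputs (E, F), broadcast inputs (X, Y), and variables (V, W):

// Sketch only: checks the argument layout described in the header comment.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"

void ParameterInfoLayoutExample() {
  tensorflow::DistributedTPURewritePass::ParameterInfo params(
      /*num_replicas=*/2, /*num_per_replica_args=*/3,
      /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
      /*num_variables=*/2, /*num_guaranteed_constants=*/0,
      /*num_retvals_per_replica=*/1);
  // Host-side argument list is A0 B0 C0 A1 B1 C1 E F X Y V W: 12 inputs.
  CHECK_EQ(params.NumInputsFromHost(), 12);
  // Each replica receives A B C E F X Y V W: 9 inputs.
  CHECK_EQ(params.NumInputsToEachReplica(), 9);
  // Within a replica's argument list, A (index 0) is per-replica and
  // E (index 3) is distributed; the first broadcast argument seen by the
  // host sits at position 2 * 3 + 2 = 8.
  CHECK(params.IsPerReplicaArg(0));
  CHECK(params.IsDistributedArg(3));
  CHECK_EQ(params.FirstBroadcastArgFromHost(), 8);
}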
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.cc (new file, 45 lines)
@@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

#include <limits>

#include "absl/random/random.h"

namespace tensorflow {
namespace {

static int64 overridden_node_id = -1;

}  // namespace

namespace internal {

void OverrideNodeIdForTesting(const int64 node_id) {
  overridden_node_id = node_id;
}

uint64 GetNodeId() {
  if (overridden_node_id > -1) {
    return overridden_node_id;
  } else {
    return absl::Uniform(absl::SharedBitGen(), uint64{0},
                         std::numeric_limits<uint64>::max());
  }
}

}  // namespace internal
}  // namespace tensorflow
tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h (new file, 38 lines)
@@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Implementation details of distributed_tpu_rewrite_pass.cc; please DO NOT
// depend on these.
namespace internal {

// When set to a value >= 0, overrides the node_id. Used for getting
// deterministic node_ids during testing.
void OverrideNodeIdForTesting(int64 node_id);

// Retrieves the node id, used to make some node names unique in the rewrite
// pass.
uint64 GetNodeId();

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_
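A brief sketch (not part of the commit) of how a test could use this hook to make generated node names deterministic; the value 42 is arbitrary:

// Sketch only: pin the node id so that generated names such as
// "TPUVariableReshard/before_iteration/_<id>" are stable across test runs.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"

void MakeNodeIdsDeterministicForTest() {
  tensorflow::internal::OverrideNodeIdForTesting(42);
  CHECK_EQ(tensorflow::internal::GetNodeId(), uint64{42});
}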
@ -0,0 +1,629 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
|
||||
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/node_hash_set.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kDefaultShardingValue[] = "";
|
||||
|
||||
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
|
||||
for (const auto e : src->out_edges()) {
|
||||
if (e->dst()->name() == dst->name()) return &(*e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Contains TPUExecute node and its DT_RESOURCE input nodes that
|
||||
// correspond to model weights.
|
||||
struct ExecuteNodeInfo {
|
||||
Node* execute_node;
|
||||
std::vector<const Edge*> var_inputs;
|
||||
};
|
||||
|
||||
// Returns whether `node` is in `execute_nodes` or `(identity) -> execute`.
|
||||
bool IsExecuteNodeOrIdentityToExecuteNode(
|
||||
const Graph& graph, const std::unordered_set<Node*>& loop_nodes, // NOLINT
|
||||
const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
|
||||
if (execute_nodes.find(node) != execute_nodes.end()) return true;
|
||||
if (loop_nodes.find(node) == loop_nodes.end()) return false;
|
||||
if (node->IsNextIteration()) return true;
|
||||
if (!node->IsIdentity()) return false;
|
||||
|
||||
for (const Edge* e : node->out_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
|
||||
Node* node = e->dst();
|
||||
if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
|
||||
node)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// From input node to the TPUExecute op, finds the corresponding Enter node
|
||||
// by searching/traversing nodes in below pattern of nodes:
|
||||
// Enter ----> (identity) ---> While body input
|
||||
// Returns nullptr if the Enter node is not found.
|
||||
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
|
||||
Node* node = input_node;
|
||||
while (node->IsIdentity()) {
|
||||
TF_RETURN_IF_ERROR(node->input_node(0, &node));
|
||||
}
|
||||
|
||||
if (node->IsEnter()) {
|
||||
return node;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
|
||||
const Graph& graph, const std::unordered_set<Node*>& loop_nodes, // NOLINT
|
||||
const Node* enter_node, const absl::flat_hash_set<Node*> execute_nodes) {
|
||||
for (const Edge* output_edge : enter_node->out_edges()) {
|
||||
Node* output_node = output_edge->dst();
|
||||
if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
|
||||
|
||||
// If output node is not execute node, it must be output node
|
||||
// to the while loop body.
|
||||
if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
|
||||
output_node)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given a TPUCompile node, find all TPUExecute nodes that executes the compiled
|
||||
// program and its model weight variable inputs as well.
|
||||
// TPUCompileMetadataProto of TPUCompile node must be reset to `new_metadata`
|
||||
// if new reshard ops are added.
|
||||
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
|
||||
const std::unordered_set<Node*>& loop_nodes, // NOLINT
|
||||
std::vector<ExecuteNodeInfo>* execute_node_info,
|
||||
TPUCompileMetadataProto* new_metadata) {
|
||||
string metadata_string;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
|
||||
new_metadata->ParsePartialFromString(metadata_string);
|
||||
if (new_metadata->num_cores_per_replica() != 1) {
|
||||
// We do not support model parallelism yet.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
execute_node_info->clear();
|
||||
for (Node* node : compile_node->out_nodes()) {
|
||||
if (node->type_string() == "TPUExecute") {
|
||||
execute_node_info->push_back({node});
|
||||
}
|
||||
}
|
||||
if (execute_node_info->empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
|
||||
<< "Number of replicas does not equal number of execute nodes: "
|
||||
<< new_metadata->num_replicas() << " vs " << execute_node_info->size();
|
||||
DataTypeVector arg_types;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
|
||||
"Targs", &arg_types));
|
||||
for (int64 i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] != DT_RESOURCE) {
|
||||
continue;
|
||||
}
|
||||
const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
|
||||
if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
|
||||
sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
|
||||
continue;
|
||||
}
|
||||
std::vector<const Edge*> edges(execute_node_info->size());
|
||||
bool is_supported = true;
|
||||
std::unordered_map<Node*, absl::flat_hash_set<Node*>>
|
||||
enter_to_execute_nodes;
|
||||
for (int64 j = 0; j < edges.size(); ++j) {
|
||||
auto execute = (*execute_node_info)[j].execute_node;
|
||||
TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
|
||||
TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
|
||||
arg_types[i])
|
||||
<< "Execute op has an unexpected input type.";
|
||||
// Traverse backwards to find the Enter node from which the input is
|
||||
// passed.
|
||||
// This makes sure that we are checking the usages of all potential
|
||||
// aliases of the input node as well.
|
||||
TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
|
||||
edges[j]->src()));
|
||||
if (enter_node == nullptr) {
|
||||
is_supported = false;
|
||||
enter_to_execute_nodes.clear();
|
||||
break;
|
||||
}
|
||||
enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
|
||||
}
|
||||
|
||||
for (const auto& it : enter_to_execute_nodes) {
|
||||
// Size of execute nodes should be either 1 (per-replica variables) or
|
||||
// num_replicas (distributed variables).
|
||||
if ((it.second.size() != 1) &&
|
||||
(it.second.size() != new_metadata->num_replicas())) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(bool no_other_use,
|
||||
ResourceOnlyUsedForTPUExecuteInLoop(
|
||||
graph, loop_nodes, it.first, it.second));
|
||||
if (!no_other_use) {
|
||||
is_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the variable input edges only when they are supported for all
|
||||
// executes.
|
||||
if (is_supported) {
|
||||
for (int64 j = 0; j < edges.size(); ++j) {
|
||||
(*execute_node_info)[j].var_inputs.push_back(edges[j]);
|
||||
}
|
||||
new_metadata->mutable_args(i)->set_enable_xla_sharding(
|
||||
TPUCompileMetadataProto::Arg::ALLOWED);
|
||||
}
|
||||
}
|
||||
|
||||
int64 total = 0;
|
||||
for (const auto& a : new_metadata->args()) {
|
||||
if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
|
||||
total++;
|
||||
}
|
||||
}
|
||||
TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
|
||||
<< " total " << total << " var_inputs "
|
||||
<< (*execute_node_info)[0].var_inputs.size();
|
||||
if (total == 0) {
|
||||
// We don't need to process anything if no input is added.
|
||||
execute_node_info->clear();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsTPUCompileOp(const Node& n) { return n.type_string() == "TPUCompile"; }
|
||||
|
||||
void FindTPUCompileNodes(
|
||||
const std::string* current_function_name,
|
||||
const AttrValueMap* current_function_attr,
|
||||
const std::unordered_map<string, WhileLoopFrame>& frames,
|
||||
std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
|
||||
// Adds frames with no children (i.e., the innermost frames) to a worklist.
|
||||
std::deque<const WhileLoopFrame*> worklist;
|
||||
|
||||
for (auto& frame : frames) {
|
||||
if (frame.second.num_children == 0) {
|
||||
worklist.push_back(&frame.second);
|
||||
}
|
||||
}
|
||||
|
||||
// Check TPUCompile node from the innermost while loop to the outermost
|
||||
// while loop.
|
||||
while (!worklist.empty()) {
|
||||
const WhileLoopFrame* frame = worklist.front();
|
||||
worklist.pop_front();
|
||||
|
||||
for (const auto& n : frame->nodes) {
|
||||
if (!IsTPUCompileOp(*n)) continue;
|
||||
|
||||
HostTrainingLoopInfo host_training_loop_info;
|
||||
host_training_loop_info.compile_node_name = n->name();
|
||||
host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
|
||||
host_training_loop_info.while_loop_name = frame->name;
|
||||
|
||||
for (const auto arg : frame->args) {
|
||||
LoopArgInfo arg_info;
|
||||
arg_info.enter_node_name = arg.enter->name();
|
||||
if (arg.exit) arg_info.exit_node_name = arg.exit->name();
|
||||
|
||||
host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
|
||||
}
|
||||
host_training_loop_info.loop_nodes = frame->nodes;
|
||||
|
||||
if (current_function_name) {
|
||||
host_training_loop_info.encapsulating_function_name =
|
||||
*current_function_name;
|
||||
}
|
||||
if (current_function_attr) {
|
||||
host_training_loop_info.encapsulating_function_attrs =
|
||||
*current_function_attr;
|
||||
}
|
||||
|
||||
host_training_loops_info->emplace_back(
|
||||
std::move(host_training_loop_info));
|
||||
}
|
||||
|
||||
// If the parent has no remaining children, add it to the worklist.
|
||||
--frame->parent->num_children;
|
||||
if (frame->parent->num_children == 0) {
|
||||
worklist.push_back(frame->parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|

// From the while loop cond node, finds all loop exit nodes by traversing
// the following pattern of nodes:
// LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
  std::vector<Node*> loop_exit_nodes;
  for (const auto e_cond : loop_cond.out_edges()) {
    if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
    auto switch_node = e_cond->dst();

    for (const auto e_switch : switch_node->out_edges()) {
      if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;

      loop_exit_nodes.push_back(e_switch->dst());
    }
  }
  return loop_exit_nodes;
}

// Find any one of the switch nodes in the while loop by traversing the graph
// from the while loop condition node.
xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
  Node* loop_switch_node = nullptr;
  for (auto n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }

  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";
  return loop_switch_node;
}
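
// Note on Switch port semantics relied upon below (standard TensorFlow
// control-flow behavior): output port 0 of a Switch carries the "false"
// branch (taken when the loop condition fails, i.e., loop exit), and output
// port 1 carries the "true" branch (another loop iteration).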
|
||||
// Returns or creates a node in that is executed before each loop iteration
|
||||
// in the while loop.
|
||||
Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
|
||||
Node** node_out) {
|
||||
// If while loop switch node already has a outgoing data to true brach
|
||||
// of the switch op, then reuse that node.
|
||||
for (const auto out_edge : loop_switch_node->out_edges()) {
|
||||
if (out_edge->src_output() == 1) {
|
||||
*node_out = out_edge->dst();
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// Create Identity node that represents execution at every loop iteration.
|
||||
NodeDef at_loop_iteration_nodedef;
|
||||
at_loop_iteration_nodedef.set_op("Identity");
|
||||
DataType dtype;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
|
||||
|
||||
AddNodeAttr("T", dtype, &at_loop_iteration_nodedef);
|
||||
at_loop_iteration_nodedef.set_name(graph->NewName(strings::StrCat(
|
||||
"TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));
|
||||
|
||||
Status status;
|
||||
Node* at_loop_iteration_node =
|
||||
graph->AddNode(at_loop_iteration_nodedef, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
|
||||
graph->AddEdge(loop_switch_node, 1, at_loop_iteration_node, 0);
|
||||
*node_out = at_loop_iteration_node;
|
||||
return Status::OK();
|
||||
}
|
||||
|

// Injects a node that is executed after the very last iteration
// of the while loop but before the while loop exit node.
Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
                                 Node** node_out) {
  // Find the exit node from the loop switch node.
  Node* exit_node = nullptr;
  for (const auto out_node : loop_switch_node->out_nodes()) {
    if (out_node->IsExit()) {
      exit_node = out_node;
      break;
    }
  }

  TF_RET_CHECK(exit_node != nullptr)
      << "Cannot find exit node connected to switch node: "
      << loop_switch_node->name();

  // Create an Identity node that represents execution at the end of the
  // last iteration of the while loop.
  NodeDef after_last_loop_iteration;
  after_last_loop_iteration.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));

  AddNodeAttr("T", dtype, &after_last_loop_iteration);
  after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));

  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_loop_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  // The newly created node must be executed once after the last iteration of
  // the while loop and before the while loop exits.
  graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
  graph->AddControlEdge(after_last_iteration_node, exit_node);
  *node_out = after_last_iteration_node;
  return Status::OK();
}

}  // namespace

Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* n : graph->nodes()) {
    const auto associated_functions = GetAssociatedFunctions(*n, library);
    if (associated_functions.empty()) continue;

    associated_function_list.insert(associated_function_list.end(),
                                    associated_functions.begin(),
                                    associated_functions.end());
  }

  Status ret_status = Status::OK();
  for (const auto& function : associated_function_list) {
    if (function.type() != AssociatedFunctionInfo::kFunctionAttr) continue;

    // Convert the function to a Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(function.func_name(),
                                        AttrSlice(&function.attrs()), &handle));
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto s = flr->ReleaseHandle(handle);
      if (!s.ok()) {
        ret_status.Update(s);
      }
    });
    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &function.func_name(), &function.attrs(), library, function_graph, flr,
        host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // As such, add edges to source/sink nodes so that this invariant is kept.
  FixupSourceAndSinkEdges(graph);
  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}

Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node: " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return Status::OK();
  }

  if (execute_nodes_info.empty()) {
    return Status::OK();
  }

  // Update the TPUCompileMetadata such that the sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of
  // TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. As such, add the while loop condition predicate as
  // an input to the sharding switch node. If the loop condition is true, we
  // do not unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node: " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // In order to make sure that shard/unshard operations are invoked
  // at the start of every loop body and at the end of the last iteration
  // of the loop, respectively, traverse the graph and find a switch node
  // of the host training loop.
  TF_ASSIGN_OR_RETURN(Node * switch_node,
                      GetLoopSwitchNode(*loop_condition_node));

  Node* after_last_iteration_node;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      graph, switch_node, &before_loop_iteration_node));

  // Create a Const op that represents the default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

  Tensor t(DT_STRING, {2});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
  TF_RETURN_IF_ERROR(status);
  // Add a control edge from the loop condition node to make sure that
  // default_sharding_node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  auto after_unshard_node = graph->AddNode(after_unshard, &status);
  TF_RETURN_IF_ERROR(status);

  for (auto info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    // Create a Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &reshard_node_def);
    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
    if (!status.ok()) return status;

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // The Reshard op must execute at every loop iteration prior to the
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

    const int new_key_input = info.var_inputs.size();
    // Add the program input edge from the compiler (i.e. the compilation
    // key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create a VarHandleOp to store the sharding state. The sharding state
    // holds the string compilation key that identifies whether the graph is
    // re-compiled and the variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
    if (!status.ok()) return status;

    // Add a control edge between the `var_handle_def` node and the while
    // loop condition so that `var_handle_def` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding the control edge from another
    // node--such as the input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect a data edge between the var handle op and the reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create a Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", static_cast<int>(info.var_inputs.size()),
                &unshard_node_def);
    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < info.var_inputs.size(); ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect model weight resource variables to the unshard op. Since the
      // unshard op must only be invoked after the very last loop iteration,
      // for each while loop input, we traverse backwards to find the switch
      // node of the host training loop and connect the `output_false` field
      // of the switch node with the unshard op.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependencies between the unshard node and the before/after
    // control nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add a data edge from the sharding state var handle op to the unshard
    // op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }
  // Add a control dependency from after_unshard_node to all exit nodes. This
  // is to make sure that the unshard ops will be executed as long as any of
  // the exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return Status::OK();
}
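
// For orientation, a sketch of the wiring produced above for each TPUExecute
// node in the host training loop (node names are illustrative; "ctrl"
// denotes a control edge):
//
//   Switch:1 (true)  -> Identity(before_iteration) -ctrl-> Reshard(shard)
//   Reshard(shard)   -ctrl-> TPUExecute
//   Switch:0 (false) -> Identity(after_last_iteration) -ctrl-> Reshard(unshard)
//   Reshard(unshard) -ctrl-> NoOp(after_unshard) -ctrl-> every loop Exit node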

}  // namespace tpu
}  // namespace tensorflow

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_

#include <string>
#include <unordered_set>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace tpu {

struct LoopArgInfo {
  std::string enter_node_name;
  // Exit nodes are optional for loop invariant while loop args.
  absl::optional<std::string> exit_node_name;
};

struct HostTrainingLoopInfo {
  // Name and attribute information about the function in which the
  // host training loop is included. If the host training loop is not
  // inside a function call, then `function_name` and `function_attrs`
  // are nullopt.
  absl::optional<std::string> encapsulating_function_name;
  absl::optional<AttrValueMap> encapsulating_function_attrs;

  // Name of the TPU compile node within the host training loop.
  std::string compile_node_name;

  // Name of the while loop in which the TPU compile op is located.
  std::string while_loop_name;

  // Name of the node that represents the loop condition.
  std::string loop_cond_node_name;

  // Exit and Enter node names for each loop argument.
  std::vector<LoopArgInfo> loop_arguments;

  std::unordered_set<Node*> loop_nodes;  // NOLINT
};

// Walks through the `graph`, recursively if functional nodes exist, and
// identifies all host training loops. Host training loops are the innermost
// while loops that encapsulate a TPUCompileOp node. This is later
// used/analyzed to introduce host loop specific optimizations such as
// adding a sharded weight update.
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info);

// Injects VariableReshardOps before and after the TPUExecute op inside the
// host training loop body. This effectively applies a sharded weight update
// on model weight variables.
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info);

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_
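
A minimal sketch of how these two helpers fit together, as they might be driven from inside a Status-returning pass; the `graph`, `flib_def`, and `flr` names here are illustrative assumptions, not part of this commit:

  std::vector<tpu::HostTrainingLoopInfo> host_training_loops;
  TF_RETURN_IF_ERROR(tpu::DetectHostTrainingLoop(
      /*current_function_name=*/nullptr, /*current_function_attr=*/nullptr,
      &flib_def, graph, flr, &host_training_loops));
  for (const auto& loop_info : host_training_loops) {
    // Rewrite each detected loop to shard/unshard model weights around the
    // TPUExecute nodes.
    TF_RETURN_IF_ERROR(tpu::AddReshardOp(graph, loop_info));
  }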
@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h"

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"

namespace tensorflow {

IncompleteNodeDefBuilder::IncompleteNodeDefBuilder(const string& name,
                                                   const string& op,
                                                   const NodeDebugInfo& debug) {
  nodedef_.set_name(name);
  nodedef_.set_op(op);
  MergeDebugInfo(debug, &nodedef_);
}

IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(
    const string& attr, const DataType& type) {
  AddNodeAttr(attr, type, &nodedef_);
  return *this;
}

IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::AddAttr(const string& attr,
                                                            int val) {
  AddNodeAttr(attr, val, &nodedef_);
  return *this;
}

IncompleteNodeDefBuilder& IncompleteNodeDefBuilder::Device(
    const string& device) {
  nodedef_.set_device(device);
  return *this;
}

Status IncompleteNodeDefBuilder::Build(Graph* graph, Node** n) {
  Status status;
  *n = graph->AddNode(nodedef_, &status);
  return status;
}

IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Identity(
    const string& name, const DataType& type, const NodeDebugInfo& debug) {
  return IncompleteNodeDefBuilder(name, "Identity", debug).AddAttr("T", type);
}

IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Merge(
    const string& name, const DataType& type, const NodeDebugInfo& debug,
    int n) {
  return IncompleteNodeDefBuilder(name, "Merge", debug)
      .AddAttr("T", type)
      .AddAttr("N", n);
}

IncompleteNodeDefBuilder IncompleteNodeDefBuilder::Switch(
    const string& name, const DataType& type, const NodeDebugInfo& debug) {
  return IncompleteNodeDefBuilder(name, "Switch", debug).AddAttr("T", type);
}

}  // namespace tensorflow
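
A hedged usage sketch of the builder above: constructing a Switch node and attaching it to a graph. The `graph` and `pred_node` variables are assumed for illustration only:

  Node* switch_node = nullptr;
  TF_RETURN_IF_ERROR(IncompleteNodeDefBuilder::Switch(
                         graph->NewName("loop_switch"), DT_BOOL,
                         NodeDebugInfo(*pred_node))
                         .Device(pred_node->assigned_device_name())
                         .Build(graph, &switch_node));
  // Inputs are wired up separately (e.g., via graph->AddEdge), which is the
  // point of the "incomplete" builder.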
@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_
#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_

#include <string>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Convenience builder for constructing NodeDefs without specifying the
// inputs; otherwise similar to NodeDefBuilder.
// TODO(jpienaar): Clean up NodeDefBuilder and remove this class.
class IncompleteNodeDefBuilder {
 public:
  IncompleteNodeDefBuilder(const string& name, const string& op,
                           const NodeDebugInfo& debug);

  IncompleteNodeDefBuilder& AddAttr(const string& attr, const DataType& type);
  IncompleteNodeDefBuilder& AddAttr(const string& attr, int val);

  IncompleteNodeDefBuilder& Device(const string& device);

  Status Build(Graph* graph, Node** n);

  static IncompleteNodeDefBuilder Identity(const string& name,
                                           const DataType& type,
                                           const NodeDebugInfo& debug);
  static IncompleteNodeDefBuilder Merge(const string& name,
                                        const DataType& type,
                                        const NodeDebugInfo& debug, int n);
  static IncompleteNodeDefBuilder Switch(const string& name,
                                         const DataType& type,
                                         const NodeDebugInfo& debug);

 private:
  NodeDef nodedef_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_NODEDEF_BUILDER_H_
@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h"
+#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h"
#include "tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h"

@ -30,8 +31,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 34,
                      EncapsulateTPUComputationsPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 39,
                      ExtractOutsideCompilationPass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 40,
+                      DistributedTPURewritePass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
                      VariableMergerPass);

}  // namespace
}  // namespace tensorflow
@ -55,8 +55,8 @@ class MultiPlatformManagerImpl {
      TF_LOCKS_EXCLUDED(mu_);

  port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
-      const std::function<bool(const Platform*)>& filter)
-      TF_LOCKS_EXCLUDED(mu_);
+      const std::function<bool(const Platform*)>& filter,
+      bool initialize_platform) TF_LOCKS_EXCLUDED(mu_);

  using Listener = MultiPlatformManager::Listener;
  port::Status RegisterListener(std::unique_ptr<Listener> listener)

@ -188,7 +188,8 @@ port::Status MultiPlatformManagerImpl::RegisterListener(

port::StatusOr<std::vector<Platform*>>
MultiPlatformManagerImpl::PlatformsWithFilter(
-    const std::function<bool(const Platform*)>& filter) {
+    const std::function<bool(const Platform*)>& filter,
+    bool initialize_platform) {
  absl::MutexLock lock(&mu_);
  CHECK_EQ(id_map_.size(), name_map_.size());
  std::vector<Platform*> platforms;

@ -196,7 +197,7 @@ MultiPlatformManagerImpl::PlatformsWithFilter(
  for (const auto& entry : id_map_) {
    Platform* platform = entry.second;
    if (filter(platform)) {
-      if (!platform->Initialized()) {
+      if (initialize_platform && !platform->Initialized()) {
        SE_RETURN_IF_ERROR(platform->Initialize({}));
      }
      platforms.push_back(platform);

@ -299,7 +300,14 @@ MultiPlatformManager::InitializePlatformWithId(
/*static*/ port::StatusOr<std::vector<Platform*>>
MultiPlatformManager::PlatformsWithFilter(
    const std::function<bool(const Platform*)>& filter) {
-  return Impl().PlatformsWithFilter(filter);
+  return PlatformsWithFilter(filter, /*initialize_platform=*/true);
}

+/*static*/ port::StatusOr<std::vector<Platform*>>
+MultiPlatformManager::PlatformsWithFilter(
+    const std::function<bool(const Platform*)>& filter,
+    bool initialize_platform) {
+  return Impl().PlatformsWithFilter(filter, initialize_platform);
+}
+
}  // namespace stream_executor
@ -130,6 +130,10 @@ class MultiPlatformManager {
  static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
      const std::function<bool(const Platform*)>& filter);

+  static port::StatusOr<std::vector<Platform*>> PlatformsWithFilter(
+      const std::function<bool(const Platform*)>& filter,
+      bool initialize_platform);
+
  // Although the MultiPlatformManager "owns" its platforms, it holds them as
  // undecorated pointers to prevent races during program exit (between this
  // object's data and the underlying platforms (e.g., CUDA, OpenCL).
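
A hypothetical call illustrating the new overload: enumerate registered platforms matching a predicate without forcing their initialization (the name-based predicate is illustrative):

  auto platforms_or =
      stream_executor::MultiPlatformManager::PlatformsWithFilter(
          [](const stream_executor::Platform* platform) {
            return platform->Name() == "TPU";
          },
          /*initialize_platform=*/false);
  if (platforms_or.ok()) {
    // Matching platforms are returned as registered; none were initialized.
    LOG(INFO) << "Matched " << platforms_or.ValueOrDie().size()
              << " platform(s).";
  }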
@ -331,6 +331,7 @@ cc_library(
    name = "tpu_topology_external",
    srcs = ["tpu_topology.cc"],
    hdrs = ["tpu_topology.h"],
+    visibility = ["//visibility:public"],
    deps = [
        ":c_api_decl",
        "//tensorflow/core/platform:types",
@ -23,10 +23,11 @@ namespace tensorflow {
namespace tpu {

namespace {
-TpuPlatformInterface* GetRegisteredPlatformStatic() {
+TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform) {
  // Prefer TpuPlatform if it's registered.
  auto status_or_tpu_platform =
-      stream_executor::MultiPlatformManager::PlatformWithName("TPU");
+      stream_executor::MultiPlatformManager::PlatformWithName(
+          "TPU", initialize_platform);
  if (status_or_tpu_platform.ok()) {
    return static_cast<TpuPlatformInterface*>(
        status_or_tpu_platform.ValueOrDie());

@ -43,7 +44,8 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {
          [](const stream_executor::Platform* platform) {
            return dynamic_cast<const TpuPlatformInterface*>(platform) !=
                   nullptr;
-          });
+          },
+          initialize_platform);
  if (!status_or_other_tpu_platforms.ok()) {
    LOG(WARNING) << "Error when getting other TPU platforms: "
                 << status_or_tpu_platform.status();

@ -64,9 +66,24 @@ TpuPlatformInterface* GetRegisteredPlatformStatic() {

/* static */
TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform() {
-  // Use a local static variable to avoid data races during initialization.
+  return GetRegisteredPlatform(/*initialize_platform=*/true);
+}
+
+/* static */
+TpuPlatformInterface* TpuPlatformInterface::GetRegisteredPlatform(
+    bool initialize_platform) {
+  static bool requested_initialize_platform = initialize_platform;
  static TpuPlatformInterface* tpu_registered_platform =
-      GetRegisteredPlatformStatic();
+      GetRegisteredPlatformStatic(initialize_platform);
+
+  if (!requested_initialize_platform && initialize_platform) {
+    // If the first call to this function did not request initializing the
+    // platform, but a later caller wants the platform initialized, call
+    // GetRegisteredPlatformStatic again to initialize the platform.
+    tpu_registered_platform = GetRegisteredPlatformStatic(initialize_platform);
+  }
+
  return tpu_registered_platform;
}
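
A hypothetical call sequence showing the re-initialization behavior described in the comment above:

  // First lookup: find a registered TPU platform without initializing it.
  auto* platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
      /*initialize_platform=*/false);
  // A later caller that needs an initialized platform triggers a second
  // lookup, which initializes the platform this time.
  platform = tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform(
      /*initialize_platform=*/true);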
@ -33,6 +33,9 @@ class TpuPlatformInterface : public stream_executor::Platform {
  // is registered or an error occurred.
  static TpuPlatformInterface* GetRegisteredPlatform();

+  // Option to not initialize a platform if not necessary.
+  static TpuPlatformInterface* GetRegisteredPlatform(bool initialize_platform);
+
  virtual Status Reset() { return Reset(false); }

  virtual Status Reset(bool only_tear_down) = 0;
@ -30,6 +30,7 @@ struct TpuChipCoordinatesExternal {

class TpuCoreLocationExternal {
 public:
+  TpuCoreLocationExternal() : core_location_(nullptr) {}
  explicit TpuCoreLocationExternal(void* core_location)
      : core_location_(core_location) {}
  TpuChipCoordinatesExternal chip_coordinates() const;