Add support for MAXIMAL/REPLICATED input sharding for model parallelism.
PiperOrigin-RevId: 298730105 Change-Id: I983bb575b367322e97c6a24803a6c8e7f8892acd
parent 3d5ed682b4
commit 7dcc0a9308
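For context: the two input sharding configurations this change handles are MAXIMAL (the value is pinned to a single logical device) and REPLICATED (every logical device receives the full value). Below is a minimal sketch, not part of this commit, of how such configurations can be built and serialized into the string carried by the "_XlaSharding" attribute, using xla::sharding_builder (already a dependency of the sharding pass touched below); the helper names are hypothetical.

#include <string>

#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Hypothetical helper: MAXIMAL sharding pins the value to one logical device.
std::string SerializedMaximalSharding(int logical_device_id) {
  xla::OpSharding sharding =
      xla::sharding_builder::AssignDevice(logical_device_id);
  return sharding.SerializeAsString();
}

// Hypothetical helper: REPLICATED sharding gives every device the full value.
std::string SerializedReplicatedSharding() {
  return xla::sharding_builder::Replicate().SerializeAsString();
}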
tensorflow/compiler/mlir/tensorflow/BUILD
@@ -337,6 +337,7 @@ cc_library(
         ":tensorflow_types",
         ":tpu_rewrite_device_util",
         ":translate_utils",
+        ":xla_sharding_util",
         "//tensorflow/compiler/mlir/lite:validators",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla:xla_proto_cc",
@@ -1154,3 +1155,19 @@ cc_library(
         "@llvm-project//mlir:Support",
     ],
 )
+
+cc_library(
+    name = "xla_sharding_util",
+    srcs = [
+        "utils/xla_sharding_util.cc",
+    ],
+    hdrs = [
+        "utils/xla_sharding_util.h",
+    ],
+    deps = [
+        ":tensorflow",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -353,23 +354,17 @@ Operation* BuildCompileOp(
       compilation_device);
 }
 
-// Creates a `tf.TPUExecute` op that executes TPU program generated by
-// `compile_op`.
-Operation* BuildExecuteOp(Operation* compile_op,
+// Creates a `tf.TPUExecute` op that executes TPU program.
+Operation* BuildExecuteOp(llvm::ArrayRef<Value> inputs,
                           tf_device::LaunchFuncOp launch_func,
                           OpBuilder* builder) {
-  // TPUExecute inherits all launch_func inputs, and takes an additional input
-  // for compilation cache key.
-  llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
-  tensor_inputs.push_back(compile_op->getResult(1));
-
   // TODO(b/139377366): Need to snapshot all resource variable inputs in
   // follow-up CLs.
 
   // TPUExecute has same output types as launch_func.
-  return builder->create<TF::TPUExecuteOp>(
-      launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs,
-      llvm::ArrayRef<NamedAttribute>{});
+  return builder->create<TF::TPUExecuteOp>(launch_func.getLoc(),
+                                           launch_func.getResultTypes(), inputs,
+                                           llvm::ArrayRef<NamedAttribute>{});
 }
 
 // Creates a tf_device.parallel_execute op that wraps TPUExecute op to
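Note on the refactor above: BuildExecuteOp no longer derives its operands from compile_op. Callers now assemble the input list themselves (the per-core values plus the compilation-cache key), which is what lets the parallel_execute path below hand each logical core a different operand set.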
@@ -394,7 +389,14 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
   auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
       launch_func.getLoc(), num_logical_cores, concatenated_output_types);
 
+  // Extract inputs for each region of the parallel_execute op. The i-th
+  // element in the list represents the inputs to the TPU computation for
+  // the i-th logical core.
+  auto input_list = tensorflow::ExtractInputsForLogicalDevices(
+      num_logical_cores, launch_func);
+
   // For each logical core, create a region with TPUExecute op.
+  assert(input_list.size() == num_logical_cores);
   for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
     auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id);
     builder->setInsertionPointToEnd(&region);
@@ -404,7 +406,9 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
     // TODO(b/148913294): Identify inputs/return values specific to each
     // logical core TPU execution by parsing xla_sharding op in
     // launch_func.
-    auto execute = BuildExecuteOp(compile_op, launch_func, builder);
+    auto execute_inputs = input_list[core_id];
+    execute_inputs.emplace_back(compile_op->getResult(1 + core_id));
+    auto execute = BuildExecuteOp(execute_inputs, launch_func, builder);
 
     // Create a launch op for each region of parallel_execute.
     //
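The indexing here assumes the compile op produces the compilation status as result 0 and the program key for logical core i as result 1 + i; the single-core path in Rewrite below appends getResult(1) accordingly.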
@@ -592,7 +596,10 @@ LogicalResult Rewrite(
     // TODO(hongjunchoi): Correctly parse TPU topology and assign logical device
     // attributes to launch_op's within parallel_execute op.
   } else {
-    Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
+    llvm::SmallVector<Value, 4> execute_inputs(launch_func.getOperands());
+    execute_inputs.emplace_back(compile_op->getResult(1));
+    Operation* execute_op =
+        BuildExecuteOp(execute_inputs, launch_func, builder);
     tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
         tpu_device_assignment.execution_devices, replicate, execute_op,
         builder);
tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
 #include "tensorflow/compiler/xla/client/sharding_builder.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -34,10 +35,6 @@ namespace mlir {
 namespace TFTPU {
 namespace {
 
-constexpr char kXlaShardingAttr[] = "_XlaSharding";
-constexpr char kInputShardingAttr[] = "input_sharding_configuration";
-constexpr char kOutputShardingAttr[] = "output_sharding_configuration";
-
 struct TPUShardingIdentificationPass
     : public ModulePass<TPUShardingIdentificationPass> {
   void runOnModule() override;
@@ -68,13 +65,6 @@ void GetAdjacentToXlaShardingOp(
   }
 }
 
-llvm::Optional<StringRef> ParseShardingAttribute(Operation* operation) {
-  const auto& sharding_attr =
-      operation->getAttrOfType<StringAttr>(kXlaShardingAttr);
-  if (!sharding_attr) return llvm::Optional<StringRef>();
-  return sharding_attr.getValue();
-}
-
 // Parse XlaSharding op connected to input args. If Input to
 // tf_device.LaunchFunc op is of resource type, then XlaSharding op
 // will be connected to following ReadVariable op.
@@ -97,7 +87,7 @@ llvm::Optional<StringRef> ParseInputSharding(const FuncOp func,
   }
 
   if (!parsed_sharding_op) return llvm::Optional<StringRef>();
-  return ParseShardingAttribute(parsed_sharding_op->getOperation());
+  return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation());
 }
 
 // If operand of return value of tf_device.LaunchFunc op is directly from
@@ -107,7 +97,7 @@ llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
                                                    const OpOperand& operand) {
   if (auto sharding_op =
           llvm::dyn_cast<TF::XlaShardingOp>(operand.get().getDefiningOp())) {
-    return ParseShardingAttribute(sharding_op.getOperation());
+    return tensorflow::ParseShardingAttribute(sharding_op.getOperation());
   }
 
   return llvm::Optional<StringRef>();
@@ -153,8 +143,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
     if (!input_arg_sharding.hasValue()) continue;
     sharding_for_args[arg_index] = input_arg_sharding->str();
   }
-  SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr,
-                                      sharding_for_args);
+  SetShardingConfigurationAsAttribute(
+      launch_func, tensorflow::kInputShardingAttr, sharding_for_args);
 
   // By default return values from logical core 0 is used if no sharding
   // configuration is defined.
@@ -176,8 +166,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
     sharding_for_return_values[return_value_index] =
         return_val_sharding->str();
   }
-  SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr,
-                                      sharding_for_return_values);
+  SetShardingConfigurationAsAttribute(
+      launch_func, tensorflow::kOutputShardingAttr, sharding_for_return_values);
 }
 
 void TPUShardingIdentificationPass::runOnModule() {
tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc (new file)
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace tensorflow {
+
+const char* const kXlaShardingAttrName = "_XlaSharding";
+const char* const kInputShardingAttr = "input_sharding_configuration";
+const char* const kOutputShardingAttr = "output_sharding_configuration";
+
+llvm::Optional<mlir::StringRef> ParseShardingAttribute(
+    mlir::Operation* operation) {
+  const auto& sharding_attr =
+      operation->getAttrOfType<mlir::StringAttr>(kXlaShardingAttrName);
+  if (!sharding_attr) return llvm::Optional<mlir::StringRef>();
+  return sharding_attr.getValue();
+}
+
+llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
+ExtractInputsForLogicalDevices(int num_logical_cores,
+                               mlir::tf_device::LaunchFuncOp launch_func) {
+  // Initialize the input list for each logical device.
+  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
+  input_list.reserve(num_logical_cores);
+  for (int i = 0; i < num_logical_cores; ++i)
+    input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
+
+  llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
+      launch_func.getOperands());
+  auto sharding_attrs =
+      launch_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+          kInputShardingAttr);
+  // If the sharding attribute does not exist, then all inputs are placed on
+  // the 0th logical core by default.
+  if (!sharding_attrs) {
+    input_list[0] = launch_func_inputs;
+    return input_list;
+  }
+
+  // Enumerate the sharding configuration for each input. If an input has
+  // replicated sharding then all logical devices take the value as input. If
+  // an input has maximal sharding then only the specified logical device
+  // takes the value as input.
+  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
+    const auto& sharding_attr = sharding_attr_and_index.value();
+    const auto input_index = sharding_attr_and_index.index();
+    const auto& input_value = launch_func_inputs[input_index];
+
+    xla::OpSharding sharding;
+    sharding.ParseFromString(
+        sharding_attr.cast<mlir::StringAttr>().getValue().str());
+
+    const auto input_sharding_type = sharding.type();
+    if (input_sharding_type == xla::OpSharding::OTHER)
+      launch_func.emitError(
+          "tiled inputs are not yet supported for model parallelism");
+
+    if (input_sharding_type == xla::OpSharding::REPLICATED) {
+      for (auto& inputs : input_list) inputs.emplace_back(input_value);
+    } else {
+      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
+      const int logical_device_id = sharding.tile_assignment_devices(0);
+      input_list[logical_device_id].emplace_back(input_value);
+    }
+  }
+  return input_list;
+}
+
+}  // namespace tensorflow
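A small worked example of the function above: with num_logical_cores = 2 and launch_func operands (%a, %b), where %a carries MAXIMAL sharding assigned to device 1 and %b is REPLICATED, the returned list is {[%b], [%a, %b]}: logical core 0 receives only the replicated %b, while logical core 1 receives both %a and %b.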
tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h (new file)
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
+
+#include "llvm/ADT/MapVector.h"
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+
+namespace tensorflow {
+
+extern const char* const kXlaShardingAttrName;
+extern const char* const kInputShardingAttr;
+extern const char* const kOutputShardingAttr;
+
+// Parses the "_XlaSharding" attribute from an operation, if it exists.
+llvm::Optional<mlir::StringRef> ParseShardingAttribute(
+    mlir::Operation* operation);
+
+// Parses the "input_sharding_configuration" attribute and returns a list
+// where the i-th element is a list of mlir::Values which represent inputs
+// for the TPU computation corresponding to the i-th logical device. If the
+// attribute does not exist, all inputs are placed on logical core 0.
+llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
+ExtractInputsForLogicalDevices(int num_logical_cores,
+                               mlir::tf_device::LaunchFuncOp launch_func);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
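A minimal usage sketch for the new utility, mirroring the BuildParallelExecuteOp change above; launch_func, compile_op, and builder are assumed to come from the surrounding rewrite, and the core count is fixed at two for illustration.

// Assumed context: inside the TPU rewrite, after BuildCompileOp has run.
constexpr int kNumLogicalCores = 2;  // illustrative value
auto input_list = tensorflow::ExtractInputsForLogicalDevices(
    kNumLogicalCores, launch_func);
for (int core_id = 0; core_id < kNumLogicalCores; ++core_id) {
  auto execute_inputs = input_list[core_id];
  // Each core also receives its compilation-cache key from the compile op.
  execute_inputs.emplace_back(compile_op->getResult(1 + core_id));
  BuildExecuteOp(execute_inputs, launch_func, builder);
}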