Add feature to support Maximal/Replicate input sharding for model parallelism.

PiperOrigin-RevId: 298730105
Change-Id: I983bb575b367322e97c6a24803a6c8e7f8892acd
A. Unique TensorFlower 2020-03-03 17:14:20 -08:00 committed by TensorFlower Gardener
parent 3d5ed682b4
commit 7dcc0a9308
5 changed files with 174 additions and 30 deletions
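For context: the sharding configurations handled by this change are carried as serialized xla::OpSharding protos. A minimal sketch (illustrative only, not part of the diff) of the two supported input sharding types, built directly through the proto API:

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// REPLICATED sharding: every logical core receives the value as an input.
std::string ReplicatedShardingString() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::REPLICATED);
  return sharding.SerializeAsString();
}

// MAXIMAL sharding: only the named logical core receives the value.
std::string MaximalShardingString(int logical_core) {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::MAXIMAL);
  sharding.add_tile_assignment_devices(logical_core);
  return sharding.SerializeAsString();
}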

View File

@@ -337,6 +337,7 @@ cc_library(
":tensorflow_types",
":tpu_rewrite_device_util",
":translate_utils",
":xla_sharding_util",
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
@@ -1154,3 +1155,19 @@ cc_library(
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "xla_sharding_util",
srcs = [
"utils/xla_sharding_util.cc",
],
hdrs = [
"utils/xla_sharding_util.h",
],
deps = [
":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)

View File

@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -353,23 +354,17 @@ Operation* BuildCompileOp(
compilation_device);
}
// Creates a `tf.TPUExecute` op that executes TPU program generated by
// `compile_op`.
Operation* BuildExecuteOp(Operation* compile_op,
// Creates a `tf.TPUExecute` op that executes TPU program.
Operation* BuildExecuteOp(llvm::ArrayRef<Value> inputs,
tf_device::LaunchFuncOp launch_func,
OpBuilder* builder) {
// TPUExecute inherits all launch_func inputs, and takes an additional input
// for compilation cache key.
llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
tensor_inputs.push_back(compile_op->getResult(1));
// TODO(b/139377366): Need to snapshot all resource variable inputs in
// follow-up CLs.
// TPUExecute has same output types as launch_func.
return builder->create<TF::TPUExecuteOp>(
launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs,
llvm::ArrayRef<NamedAttribute>{});
return builder->create<TF::TPUExecuteOp>(launch_func.getLoc(),
launch_func.getResultTypes(), inputs,
llvm::ArrayRef<NamedAttribute>{});
}
// Creates a tf_device.parallel_execute op that wraps TPUExecute op to
@@ -394,7 +389,14 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
launch_func.getLoc(), num_logical_cores, concatenated_output_types);
// Extract inputs for each region of the parallel_execute op. The i-th
// element in the list represents the inputs to the TPU computation for
// the i-th logical core.
auto input_list = tensorflow::ExtractInputsForLogicalDevices(
num_logical_cores, launch_func);
// For each logical core, create a region with TPUExecute op.
assert(input_list.size() == num_logical_cores);
for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id);
builder->setInsertionPointToEnd(&region);
@@ -404,7 +406,9 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
// TODO(b/148913294): Identify inputs/return values specific to each
// logical core TPU execution by parsing xla_sharding op in
// launch_func.
auto execute = BuildExecuteOp(compile_op, launch_func, builder);
auto execute_inputs = input_list[core_id];
execute_inputs.emplace_back(compile_op->getResult(1 + core_id));
auto execute = BuildExecuteOp(execute_inputs, launch_func, builder);
// Create a launch op for each region of parallel_execute.
//
@@ -592,7 +596,10 @@ LogicalResult Rewrite(
// TODO(hongjunchoi): Correctly parse TPU topology and assign logical device
// attributes to launch_op's within parallel_execute op.
} else {
Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
llvm::SmallVector<Value, 4> execute_inputs(launch_func.getOperands());
execute_inputs.emplace_back(compile_op->getResult(1));
Operation* execute_op =
BuildExecuteOp(execute_inputs, launch_func, builder);
tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
tpu_device_assignment.execution_devices, replicate, execute_op,
builder);
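To make the per-core data flow introduced above concrete, a short worked example follows; the operand values and the interpretation of compile_op result 0 are assumptions for illustration, not stated in this diff:

// Illustrative only: num_logical_cores = 2, launch_func operands [%a, %b, %c],
// input shardings [REPLICATED, MAXIMAL(core 0), MAXIMAL(core 1)].
//
// compile_op results (assumed layout): result 0 = compilation status,
// result 1 = program for logical core 0, result 2 = program for logical core 1.
//
// Per-core TPUExecute inputs built by BuildParallelExecuteOp:
//   core 0: [%a, %b, compile_op->getResult(1)]
//   core 1: [%a, %c, compile_op->getResult(2)]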

View File

@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -34,10 +35,6 @@ namespace mlir {
namespace TFTPU {
namespace {
constexpr char kXlaShardingAttr[] = "_XlaSharding";
constexpr char kInputShardingAttr[] = "input_sharding_configuration";
constexpr char kOutputShardingAttr[] = "output_sharding_configuration";
struct TPUShardingIdentificationPass
: public ModulePass<TPUShardingIdentificationPass> {
void runOnModule() override;
@@ -68,13 +65,6 @@ void GetAdjacentToXlaShardingOp(
}
}
llvm::Optional<StringRef> ParseShardingAttribute(Operation* operation) {
const auto& sharding_attr =
operation->getAttrOfType<StringAttr>(kXlaShardingAttr);
if (!sharding_attr) return llvm::Optional<StringRef>();
return sharding_attr.getValue();
}
// Parses the XlaSharding op connected to the input args. If the input to the
// tf_device.LaunchFunc op is of resource type, then the XlaSharding op
// will be connected to the following ReadVariable op.
@@ -97,7 +87,7 @@ llvm::Optional<StringRef> ParseInputSharding(const FuncOp func,
}
if (!parsed_sharding_op) return llvm::Optional<StringRef>();
return ParseShardingAttribute(parsed_sharding_op->getOperation());
return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation());
}
// If operand of return value of tf_device.LaunchFunc op is directly from
@@ -107,7 +97,7 @@ llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
const OpOperand& operand) {
if (auto sharding_op =
llvm::dyn_cast<TF::XlaShardingOp>(operand.get().getDefiningOp())) {
return ParseShardingAttribute(sharding_op.getOperation());
return tensorflow::ParseShardingAttribute(sharding_op.getOperation());
}
return llvm::Optional<StringRef>();
@@ -153,8 +143,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
if (!input_arg_sharding.hasValue()) continue;
sharding_for_args[arg_index] = input_arg_sharding->str();
}
SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr,
sharding_for_args);
SetShardingConfigurationAsAttribute(
launch_func, tensorflow::kInputShardingAttr, sharding_for_args);
// By default, return values from logical core 0 are used if no sharding
// configuration is defined.
@@ -176,8 +166,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
sharding_for_return_values[return_value_index] =
return_val_sharding->str();
}
SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr,
sharding_for_return_values);
SetShardingConfigurationAsAttribute(
launch_func, tensorflow::kOutputShardingAttr, sharding_for_return_values);
}
void TPUShardingIdentificationPass::runOnModule() {

View File

@@ -0,0 +1,86 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow {
const char* const kXlaShardingAttrName = "_XlaSharding";
const char* const kInputShardingAttr = "input_sharding_configuration";
const char* const kOutputShardingAttr = "output_sharding_configuration";
llvm::Optional<mlir::StringRef> ParseShardingAttribute(
mlir::Operation* operation) {
const auto& sharding_attr =
operation->getAttrOfType<mlir::StringAttr>(kXlaShardingAttrName);
if (!sharding_attr) return llvm::Optional<mlir::StringRef>();
return sharding_attr.getValue();
}
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
ExtractInputsForLogicalDevices(int num_logical_cores,
mlir::tf_device::LaunchFuncOp launch_func) {
// Initialize the input list for each logical device.
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
input_list.reserve(num_logical_cores);
for (int i = 0; i < num_logical_cores; ++i)
input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
launch_func.getOperands());
auto sharding_attrs =
launch_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
kInputShardingAttr);
// If the sharding attribute does not exist, then all inputs are placed on the
// 0th logical core by default.
if (!sharding_attrs) {
input_list[0] = launch_func_inputs;
return input_list;
}
// Enumerate the sharding configuration for each input. If an input has
// replicated sharding, then all logical devices take the value as an input. If
// an input has maximal sharding, then only the specified logical device takes
// the value as an input.
for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
const auto& sharding_attr = sharding_attr_and_index.value();
const auto input_index = sharding_attr_and_index.index();
const auto& input_value = launch_func_inputs[input_index];
xla::OpSharding sharding;
sharding.ParseFromString(
sharding_attr.cast<mlir::StringAttr>().getValue().str());
const auto input_sharding_type = sharding.type();
if (input_sharding_type == xla::OpSharding::OTHER)
launch_func.emitError(
"tiled inputs are not yet supported for model parallelism");
if (input_sharding_type == xla::OpSharding::REPLICATED) {
for (auto& inputs : input_list) inputs.emplace_back(input_value);
} else {
assert(input_sharding_type == xla::OpSharding::MAXIMAL);
const int logical_device_id = sharding.tile_assignment_devices(0);
input_list[logical_device_id].emplace_back(input_value);
}
}
return input_list;
}
} // namespace tensorflow

View File

@@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
#include "llvm/ADT/MapVector.h"
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
namespace tensorflow {
extern const char* const kXlaShardingAttrName;
extern const char* const kInputShardingAttr;
extern const char* const kOutputShardingAttr;
// Parse "_XlaSharding" attribute from operation, if it exists.
llvm::Optional<mlir::StringRef> ParseShardingAttribute(
mlir::Operation* operation);
// Parses "input_sharding_configuration" attribute and returns a list where
// i-th element is a list of mlir::Value's which represent inputs for the
// TPU computation correponding to i-th logical device. If the attribute
// does not exist, the all inputs are placed on logical core 0.
llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
ExtractInputsForLogicalDevices(int num_logical_cores,
mlir::tf_device::LaunchFuncOp launch_func);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
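A minimal caller-side sketch of the new helper, mirroring how BuildParallelExecuteOp in the rewrite pass above consumes it; the surrounding function and the interpretation of the compile op's results are assumptions for illustration:

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"

// Sketch: distribute launch_func operands across logical cores and append each
// core's compiled program before building that core's TPUExecute op.
void AssemblePerCoreExecuteInputs(int num_logical_cores,
                                  mlir::tf_device::LaunchFuncOp launch_func,
                                  mlir::Operation* compile_op) {
  auto input_list = tensorflow::ExtractInputsForLogicalDevices(
      num_logical_cores, launch_func);
  for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
    llvm::SmallVector<mlir::Value, 4> execute_inputs = input_list[core_id];
    // Assumed: result 0 of the compile op is the compilation status and
    // results 1..num_logical_cores hold the per-core TPU programs.
    execute_inputs.emplace_back(compile_op->getResult(1 + core_id));
    // ... create a TPUExecute op for this core with `execute_inputs`.
  }
}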