From 7dcc0a93089cc868f8e1dbd2091ff9f13c004fd4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 3 Mar 2020 17:14:20 -0800
Subject: [PATCH] Add feature to support Maximal/Replicate input sharding for
 model parallelism.

PiperOrigin-RevId: 298730105
Change-Id: I983bb575b367322e97c6a24803a6c8e7f8892acd
---
 tensorflow/compiler/mlir/tensorflow/BUILD     | 17 ++++
 .../tensorflow/transforms/tpu_rewrite_pass.cc | 33 ++++---
 .../tpu_sharding_identification_pass.cc       | 24 ++----
 .../tensorflow/utils/xla_sharding_util.cc     | 86 +++++++++++++++++++
 .../mlir/tensorflow/utils/xla_sharding_util.h | 44 ++++++++++
 5 files changed, 174 insertions(+), 30 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
 create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h

diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index d52fd0c3b72..6e315509ed9 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -337,6 +337,7 @@ cc_library(
         ":tensorflow_types",
         ":tpu_rewrite_device_util",
         ":translate_utils",
+        ":xla_sharding_util",
        "//tensorflow/compiler/mlir/lite:validators",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla:xla_proto_cc",
@@ -1154,3 +1155,19 @@ cc_library(
         "@llvm-project//mlir:Support",
     ],
 )
+
+cc_library(
+    name = "xla_sharding_util",
+    srcs = [
+        "utils/xla_sharding_util.cc",
+    ],
+    hdrs = [
+        "utils/xla_sharding_util.h",
+    ],
+    deps = [
+        ":tensorflow",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "@llvm-project//llvm:support",
+        "@llvm-project//mlir:IR",
+    ],
+)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 7b0291a2f9b..3e85cb57318 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
 #include "tensorflow/compiler/xla/xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -353,23 +354,17 @@ Operation* BuildCompileOp(
       compilation_device);
 }
 
-// Creates a `tf.TPUExecute` op that executes TPU program generated by
-// `compile_op`.
-Operation* BuildExecuteOp(Operation* compile_op,
+// Creates a `tf.TPUExecute` op that executes the TPU program.
+Operation* BuildExecuteOp(llvm::ArrayRef<Value> inputs,
                           tf_device::LaunchFuncOp launch_func,
                           OpBuilder* builder) {
-  // TPUExecute inherits all launch_func inputs, and takes an additional input
-  // for compilation cache key.
-  llvm::SmallVector<Value, 4> tensor_inputs(launch_func.getOperands());
-  tensor_inputs.push_back(compile_op->getResult(1));
-
   // TODO(b/139377366): Need to snapshot all resource variable inputs in
   // follow-up CLs.
 
   // TPUExecute has same output types as launch_func.
-  return builder->create<TF::TPUExecuteOp>(
-      launch_func.getLoc(), launch_func.getResultTypes(), tensor_inputs,
-      llvm::ArrayRef<NamedAttribute>{});
+  return builder->create<TF::TPUExecuteOp>(launch_func.getLoc(),
+                                           launch_func.getResultTypes(), inputs,
+                                           llvm::ArrayRef<NamedAttribute>{});
 }
 
 // Creates a tf_device.parallel_execute op that wraps TPUExecute op to
@@ -394,7 +389,14 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
   auto parallel_execute_op = builder->create<tf_device::ParallelExecuteOp>(
       launch_func.getLoc(), num_logical_cores, concatenated_output_types);
 
+  // Extract inputs for each region of the parallel_execute op. The i-th
+  // element in the list represents the input list to the TPU computation for
+  // the i-th logical core.
+  auto input_list = tensorflow::ExtractInputsForLogicalDevices(
+      num_logical_cores, launch_func);
+
   // For each logical core, create a region with TPUExecute op.
+  assert(input_list.size() == num_logical_cores);
   for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
     auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id);
     builder->setInsertionPointToEnd(&region);
@@ -404,7 +406,9 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
     // TODO(b/148913294): Identify inputs/return values specific to each
     // logical core TPU execution by parsing xla_sharding op in
     // launch_func.
-    auto execute = BuildExecuteOp(compile_op, launch_func, builder);
+    auto execute_inputs = input_list[core_id];
+    execute_inputs.emplace_back(compile_op->getResult(1 + core_id));
+    auto execute = BuildExecuteOp(execute_inputs, launch_func, builder);
 
     // Create a launch op for each region of parallel_execute.
     //
@@ -592,7 +596,10 @@ LogicalResult Rewrite(
     // TODO(hongjunchoi): Correctly parse TPU topology and assign logical device
     // attributes to launch_op's within parallel_execute op.
   } else {
-    Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
+    llvm::SmallVector<Value, 4> execute_inputs(launch_func.getOperands());
+    execute_inputs.emplace_back(compile_op->getResult(1));
+    Operation* execute_op =
+        BuildExecuteOp(execute_inputs, launch_func, builder);
     tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
         tpu_device_assignment.execution_devices, replicate, execute_op,
         builder);
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
index 244df85f482..bbe3b80ae36 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // TF:llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,10 +35,6 @@ namespace mlir { namespace TFTPU { namespace { -constexpr char kXlaShardingAttr[] = "_XlaSharding"; -constexpr char kInputShardingAttr[] = "input_sharding_configuration"; -constexpr char kOutputShardingAttr[] = "output_sharding_configuration"; - struct TPUShardingIdentificationPass : public ModulePass { void runOnModule() override; @@ -68,13 +65,6 @@ void GetAdjacentToXlaShardingOp( } } -llvm::Optional ParseShardingAttribute(Operation* operation) { - const auto& sharding_attr = - operation->getAttrOfType(kXlaShardingAttr); - if (!sharding_attr) return llvm::Optional(); - return sharding_attr.getValue(); -} - // Parse XlaSharding op connected to input args. If Input to // tf_device.LaunchFunc op is of resource type, then XlaSharding op // will be connected to following ReadVariable op. @@ -97,7 +87,7 @@ llvm::Optional ParseInputSharding(const FuncOp func, } if (!parsed_sharding_op) return llvm::Optional(); - return ParseShardingAttribute(parsed_sharding_op->getOperation()); + return tensorflow::ParseShardingAttribute(parsed_sharding_op->getOperation()); } // If operand of return value of tf_device.LaunchFunc op is directly from @@ -107,7 +97,7 @@ llvm::Optional ParseReturnValueSharding(FuncOp func, const OpOperand& operand) { if (auto sharding_op = llvm::dyn_cast(operand.get().getDefiningOp())) { - return ParseShardingAttribute(sharding_op.getOperation()); + return tensorflow::ParseShardingAttribute(sharding_op.getOperation()); } return llvm::Optional(); @@ -153,8 +143,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { if (!input_arg_sharding.hasValue()) continue; sharding_for_args[arg_index] = input_arg_sharding->str(); } - SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr, - sharding_for_args); + SetShardingConfigurationAsAttribute( + launch_func, tensorflow::kInputShardingAttr, sharding_for_args); // By default return values from logical core 0 is used if no sharding // configuration is defined. @@ -176,8 +166,8 @@ void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) { sharding_for_return_values[return_value_index] = return_val_sharding->str(); } - SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr, - sharding_for_return_values); + SetShardingConfigurationAsAttribute( + launch_func, tensorflow::kOutputShardingAttr, sharding_for_return_values); } void TPUShardingIdentificationPass::runOnModule() { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc new file mode 100644 index 00000000000..57e51b2c812 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace tensorflow {
+
+const char* const kXlaShardingAttrName = "_XlaSharding";
+const char* const kInputShardingAttr = "input_sharding_configuration";
+const char* const kOutputShardingAttr = "output_sharding_configuration";
+
+llvm::Optional<llvm::StringRef> ParseShardingAttribute(
+    mlir::Operation* operation) {
+  const auto& sharding_attr =
+      operation->getAttrOfType<mlir::StringAttr>(kXlaShardingAttrName);
+  if (!sharding_attr) return llvm::Optional<llvm::StringRef>();
+  return sharding_attr.getValue();
+}
+
+llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
+ExtractInputsForLogicalDevices(int num_logical_cores,
+                               mlir::tf_device::LaunchFuncOp launch_func) {
+  // Initialize the input list for each logical device.
+  llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4> input_list;
+  input_list.reserve(num_logical_cores);
+  for (int i = 0; i < num_logical_cores; ++i)
+    input_list.emplace_back(llvm::SmallVector<mlir::Value, 4>());
+
+  llvm::SmallVector<mlir::Value, 4> launch_func_inputs(
+      launch_func.getOperands());
+  auto sharding_attrs =
+      launch_func.getOperation()->getAttrOfType<mlir::ArrayAttr>(
+          kInputShardingAttr);
+  // If the sharding attribute does not exist, then all inputs are placed on
+  // the 0th logical core by default.
+  if (!sharding_attrs) {
+    input_list[0] = launch_func_inputs;
+    return input_list;
+  }
+
+  // Enumerate the sharding configuration for each input. If an input has
+  // replicate sharding, then all logical devices take the value as input. If
+  // an input has maximal sharding, then only the specified logical device
+  // takes the value as the input.
+  for (const auto& sharding_attr_and_index : llvm::enumerate(sharding_attrs)) {
+    const auto& sharding_attr = sharding_attr_and_index.value();
+    const auto input_index = sharding_attr_and_index.index();
+    const auto& input_value = launch_func_inputs[input_index];
+
+    xla::OpSharding sharding;
+    sharding.ParseFromString(
+        sharding_attr.cast<mlir::StringAttr>().getValue().str());
+
+    const auto input_sharding_type = sharding.type();
+    if (input_sharding_type == xla::OpSharding::OTHER)
+      launch_func.emitError(
+          "tiled inputs are not yet supported for model parallelism");
+
+    if (input_sharding_type == xla::OpSharding::REPLICATED) {
+      for (auto& inputs : input_list) inputs.emplace_back(input_value);
+    } else {
+      assert(input_sharding_type == xla::OpSharding::MAXIMAL);
+      const int logical_device_id = sharding.tile_assignment_devices(0);
+      input_list[logical_device_id].emplace_back(input_value);
+    }
+  }
+  return input_list;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
new file mode 100644
index 00000000000..7c5bad9fc16
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
+
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/IR/Value.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+
+namespace tensorflow {
+
+extern const char* const kXlaShardingAttrName;
+extern const char* const kInputShardingAttr;
+extern const char* const kOutputShardingAttr;
+
+// Parses the "_XlaSharding" attribute from the operation, if it exists.
+llvm::Optional<llvm::StringRef> ParseShardingAttribute(
+    mlir::Operation* operation);
+
+// Parses the "input_sharding_configuration" attribute and returns a list
+// where the i-th element is a list of mlir::Value's which represent inputs
+// for the TPU computation corresponding to the i-th logical device. If the
+// attribute does not exist, all inputs are placed on logical core 0.
+llvm::SmallVector<llvm::SmallVector<mlir::Value, 4>, 4>
+ExtractInputsForLogicalDevices(int num_logical_cores,
+                               mlir::tf_device::LaunchFuncOp launch_func);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_
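
Usage note (illustrative only, not part of the patch): the sketch below shows how a caller might build the serialized xla::OpSharding strings that ExtractInputsForLogicalDevices() consumes, using the xla::sharding_builder helpers already included by the sharding identification pass. The helper name AnnotateInputSharding and the two-operand, two-core setup are hypothetical.

#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"

// Hypothetical helper: attaches an input sharding configuration to a
// launch_func that has two operands and targets two logical cores.
void AnnotateInputSharding(mlir::tf_device::LaunchFuncOp launch_func) {
  mlir::Builder builder(launch_func.getContext());

  // Operand 0 is replicated: every logical core receives it as an input.
  std::string replicated =
      xla::sharding_builder::Replicate().SerializeAsString();
  // Operand 1 has maximal sharding pinned to logical core 1: only that core
  // receives it.
  std::string maximal =
      xla::sharding_builder::AssignDevice(1).SerializeAsString();

  launch_func.setAttr(tensorflow::kInputShardingAttr,
                      builder.getStrArrayAttr({replicated, maximal}));

  // With this attribute set, ExtractInputsForLogicalDevices(2, launch_func)
  // returns {{operand0}, {operand0, operand1}}.
}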