Add MLIR transformation pass that identifies XLA sharding ops connected to
inputs/outputs of a TPU computation. This information will later be used to
correctly fan in/fan out inputs/outputs of the TPU computation to/from
logical cores.

PiperOrigin-RevId: 297952065
Change-Id: I1830eb9139a68454c158dc1979139152772ebf82
A. Unique TensorFlower 2020-02-28 15:45:44 -08:00 committed by TensorFlower Gardener
parent 3a8c575c4e
commit 9cb019e654
4 changed files with 0 additions and 354 deletions
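For context, a minimal sketch of the pass's effect, adapted from the test cases removed below; the function name @computation and the tensor type are illustrative, and attribute ordering is not significant.

Before the pass, a launch op carries no sharding information:

  %0 = "tf_device.launch_func"(%arg0) {device = "", func = @computation, step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>

After the pass, serialized sharding configurations (xla::OpSharding protos) are attached as array attributes, one entry per input and per output, defaulting to maximal sharding on logical core 0 when no XlaSharding op is found:

  %0 = "tf_device.launch_func"(%arg0) {device = "", func = @computation, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>

The default string above is the serialized result of xla::sharding_builder::AssignDevice(0), matching the CHECK lines in the removed test file.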

@@ -304,7 +304,6 @@ cc_library(
"transforms/tpu_dynamic_padding_mapper.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_variable_runtime_reformatting.cc",
"translate/breakup-islands.cc",
"translate/control_to_executor_dialect.cc",
@@ -336,7 +335,6 @@ cc_library(
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

@@ -1,147 +0,0 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck %s --dump-input=fail
// Tests empty launch func. Empty input/output sharding configuration
// attributes must be added.
// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_launch_func
func @check_sharding_attrs_exists_for_empty_launch_func() {
  "tf_device.launch_func"() {device = "", func = @empty_func, step_marker_location = ""} : () -> ()
  // CHECK: input_sharding_configuration = []
  // CHECK: output_sharding_configuration = []
  return
}

func @empty_func() {
  return
}

// -----

// Tests that inputs/outputs with no XlaSharding op attached get the default
// maximal sharding configuration (assigned to logical core 0).
// CHECK-LABEL: func @check_default_sharding_for_inputs_outputs
func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
  "tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> ()
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
  return
}

func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  %0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

// -----

// Tests an input arg connected to an XlaSharding op.
// CHECK-LABEL: func @check_sharding_for_input_correctly_identified
func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
  "tf_device.launch_func"(%arg0) {device = "", func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> ()
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\01\02\03"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
  return
}

func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
  return %1 : tensor<*xi32>
}

// -----

// Tests that sharding is correctly parsed for multiple inputs/outputs.
// CHECK-LABEL: func @check_sharding_for_multiple_inputs_outputs
func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
  "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\01\02\03", "\04\05\06"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
  return
}

func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %2, %3 = "tf.A"(%0, %1) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %4, %5 : tensor<*xi32>, tensor<*xi1>
}

// -----

// Tests with input sharding following an identity op.
// CHECK-LABEL: func @check_sharding_after_identity
func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
  "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\01\02\03", "\04\05\06"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
  return
}

func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
  %2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %3, %4 = "tf.A"(%1, %2) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %5, %6 : tensor<*xi32>, tensor<*xi1>
}

// -----

// Tests with input sharding following a ReadVariable op.
// CHECK-LABEL: func @check_sharding_after_read_variable
func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
  "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\01\02\03", "\04\05\06"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
  return
}

func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
  %1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
  %2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
  %3 = "tf.Identity"(%2) : (tensor<32xf32>) -> tensor<32xf32>
  %4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
  %5, %6 = "tf.A"(%1, %3) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<*xi32>, tensor<*xi1>)
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %7, %8 : tensor<*xi32>, tensor<*xi1>
}

// -----

// Tests with input sharding following an identity op and cast op.
// CHECK-LABEL: func @check_sharding_after_cast_op
func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
  "tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  // CHECK: input_sharding_configuration
  // CHECK-SAME: ["\01\02\03", "\04\05\06"]
  // CHECK: output_sharding_configuration
  // CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
  return
}

func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
  %0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
  %1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
  %2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
  %3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
  %4, %5 = "tf.A"(%2, %3) : (tensor<*xi1>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
  %6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
  %7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
  return %6, %7 : tensor<*xi32>, tensor<*xi1>
}

@@ -209,10 +209,6 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
// ops.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass();
// Creates a pass that identifies XlaSharding ops within a launch op for TPU
// computation.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass();
// Creates a pass that merges device variable reads/updates into the surrounded
// TPUExecute node. This allows the execute node to perform in-place variable
// updates.

@@ -1,201 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/UseDefLists.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kXlaShardingAttr[] = "_XlaSharding";
constexpr char kInputShardingAttr[] = "input_sharding_configuration";
constexpr char kOutputShardingAttr[] = "output_sharding_configuration";
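// A pass that annotates tf_device.launch_func ops with input/output sharding
// configurations parsed from connected XlaSharding ops.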
struct TPUShardingIdentificationPass
    : public ModulePass<TPUShardingIdentificationPass> {
  void runOnModule() override;
};
// An XlaSharding op may be a direct user of an input, but it may also be
// followed by an Identity op and, in the case where the bfloat16 type is
// used, a Cast op may be inserted right after the input. As such, walk the
// users of the operation to find the connected XlaSharding op.
//
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
void GetAdjacentToXlaShardingOp(
    Operation* op, llvm::Optional<TF::XlaShardingOp>* sharding_op) {
  // TODO(hongjunchoi): Detect the case when sharding configuration is
  // ambiguous for a single input (i.e. multiple different XlaSharding ops
  // with different configuration policies are connected).
  if (sharding_op->hasValue()) return;

  if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(op)) {
    sharding_op->emplace(sharding);
    return;
  }

  if (llvm::isa<TF::IdentityOp>(op) || llvm::isa<TF::CastOp>(op)) {
    for (auto user : op->getUsers())
      GetAdjacentToXlaShardingOp(user, sharding_op);
  }
}
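// Returns the sharding configuration carried by the _XlaSharding attribute of
// the given operation, if present.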
llvm::Optional<StringRef> ParseShardingAttribute(Operation* operation) {
  const auto& sharding_attr =
      operation->getAttrOfType<StringAttr>(kXlaShardingAttr);
  if (!sharding_attr) return llvm::Optional<StringRef>();
  return sharding_attr.getValue();
}
// Parses the XlaSharding op connected to an input arg. If the input to the
// tf_device.LaunchFunc op is of resource type, then the XlaSharding op will
// be connected to the following ReadVariable op.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a
// Call op or if/while op.
llvm::Optional<StringRef> ParseInputSharding(const FuncOp func,
                                             const int arg_index,
                                             const Value& arg) {
  llvm::Optional<TF::XlaShardingOp> parsed_sharding_op;
  for (auto user : arg.getUsers()) {
    if (parsed_sharding_op) continue;

    GetAdjacentToXlaShardingOp(user, &parsed_sharding_op);
    if (parsed_sharding_op) continue;

    if (llvm::isa<TF::ReadVariableOp>(user))
      for (auto user : user->getUsers())
        GetAdjacentToXlaShardingOp(user, &parsed_sharding_op);
  }

  if (!parsed_sharding_op) return llvm::Optional<StringRef>();
  return ParseShardingAttribute(parsed_sharding_op->getOperation());
}
// If the operand of a return value of the tf_device.LaunchFunc op comes
// directly from an XlaSharding op, returns the provided sharding
// configuration.
llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
                                                   const int output_index,
                                                   const OpOperand& operand) {
  if (auto sharding_op =
          llvm::dyn_cast<TF::XlaShardingOp>(operand.get().getDefiningOp())) {
    return ParseShardingAttribute(sharding_op.getOperation());
  }

  return llvm::Optional<StringRef>();
}
// Add parsed sharding configuration to tf_device.LaunchFunc op attribute.
void SetShardingConfigurationAsAttribute(
    tf_device::LaunchFuncOp launch_func, const std::string& attr_name,
    const llvm::SmallVector<std::string, 8>& sharding_config) {
  auto input_sharding_array_ref = llvm::SmallVector<llvm::StringRef, 8>(
      sharding_config.begin(), sharding_config.end());
  launch_func.setAttr(attr_name,
                      mlir::Builder(launch_func.getContext())
                          .getStrArrayAttr(input_sharding_array_ref));
}
// If an XlaSharding op is connected to an input/output of the
// tf_device.LaunchFunc op, adds attributes to the op specifying the sharding
// configurations.
void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
  // Look up the function definition from the module.
  FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      launch_func.func());
  Block& func_entry_block = func.getBody().getBlocks().front();

  // By default, inputs have maximal sharding and are assigned to logical
  // core 0 if no sharding is defined.
  llvm::SmallVector<std::string, 8> sharding_for_args(
      func_entry_block.getNumArguments(),
      xla::sharding_builder::AssignDevice(0).SerializeAsString());

  // Iterate through the input arguments to the entry block of
  // tf_device.LaunchFunc. For input ops, look for following XlaSharding ops.
  // XlaSharding ops can:
  //   1) Directly follow the input argument if the input argument has
  //      non-resource type.
  //   2) Follow a ReadVariableOp if the input type is of resource type.
  //   3) Follow an IdentityOp or CastOp after the above cases (1), (2).
  for (auto& arg_index_and_value :
       llvm::enumerate(func_entry_block.getArguments())) {
    const int arg_index = arg_index_and_value.index();
    auto& arg = arg_index_and_value.value();
    auto input_arg_sharding = ParseInputSharding(func, arg_index, arg);
    if (!input_arg_sharding.hasValue()) continue;

    sharding_for_args[arg_index] = input_arg_sharding->str();
  }
  SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr,
                                      sharding_for_args);

  // By default, return values from logical core 0 are used if no sharding
  // configuration is defined.
  llvm::SmallVector<std::string, 8> sharding_for_return_values(
      func_entry_block.getTerminator()->getNumOperands(),
      xla::sharding_builder::AssignDevice(0).SerializeAsString());

  // Iterate through the operands of the terminator; if the defining op of an
  // operand is an XlaShardingOp, then add the provided sharding configuration
  // to the launch func attribute.
  for (auto& return_value_and_index :
       llvm::enumerate(func_entry_block.getTerminator()->getOpOperands())) {
    int return_value_index = return_value_and_index.index();
    const auto& return_value = return_value_and_index.value();
    auto return_val_sharding =
        ParseReturnValueSharding(func, return_value_index, return_value);
    if (return_val_sharding)
      sharding_for_return_values[return_value_index] =
          return_val_sharding->str();
  }
  SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr,
                                      sharding_for_return_values);
}
void TPUShardingIdentificationPass::runOnModule() {
  getModule().walk([&](tf_device::LaunchFuncOp launch_func) {
    IdentifyXlaShardingForTPUComputation(launch_func);
  });
}

}  // anonymous namespace

std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}

static PassRegistration<TPUShardingIdentificationPass> pass(
    "tf-tpu-sharding-identification",
    "Identifies and handles inputs/outputs of TPU computation that is "
    "sharded across logical cores.");

}  // namespace TFTPU
}  // namespace mlir