Implement outside compilation head extraction.

PiperOrigin-RevId: 311172756
Change-Id: Id3dbcbd1582a01ec94424dbb8b08bb475466568c
This commit is contained in:
A. Unique TensorFlower 2020-05-12 11:57:52 -07:00 committed by TensorFlower Gardener
parent 1712a14d01
commit 1de39b5756
3 changed files with 247 additions and 32 deletions

View File

@ -1,13 +1,17 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure
// Tests extraction of a single outside compiled cluster with no input or output dependecies.
// Tests extraction of a outside compiled ops at head of TPU computation.
// CHECK-LABEL: func @nodep_single_head_outside_compilation
func @nodep_single_head_outside_compilation() -> () {
// CHECK: "tf.A"
// CHECK-NEXT: "tf_device.launch"
"tf_device.launch"() ( {
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: tf_device.launch
// CHECK: "tf.A"
// CHECK-NEXT: tf_device.return
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
"tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.B"() : () -> ()
"tf.C"() : () -> ()
tf_device.return
@ -15,15 +19,62 @@ func @nodep_single_head_outside_compilation() -> () {
return
}
// CHECK-LABEL: func @nodep_multiple_head_outside_compilation
func @nodep_multiple_head_outside_compilation() -> () {
// CHECK: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf_device.launch"
"tf_device.launch"() ( {
"tf.A"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
// CHECK-LABEL: func @multiple_head_outside_compilation
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK: "tf.C"
// CHECK-NEXT: tf_device.return %[[B_OUT]]
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> ()
"tf.D"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle
func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor<i32>) -> () {
// CHECK-NOT: tf_device.launch
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
"tf.C"(%1) : (tensor<i32>) -> ()
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted
func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor<i32>) -> () {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK: %[[A_OUT:.*]] = "tf.A"
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
// CHECK-NEXT: tf_device.return %[[D_OUT]]
//
// CHECK: "tf_device.cluster"
// CHECK: "tf.B"
// CHECK: "tf.C"
// CHECK: "tf.E"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
%1 = "tf.B"() {} : () -> (tensor<i32>)
%2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
%3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor<i32>) -> (tensor<i32>)
%4 = "tf.E"(%3) {} : (tensor<i32>) -> (tensor<i32>)
tf_device.return
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return

View File

@ -258,7 +258,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTPUVariableReformattingPass();
// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// at head/tail of TPU cluster to run before/after TPU computation.
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass();
// Creates a pass that extract outside compilation (CPU ops inside TPU cluster)

View File

@ -14,11 +14,23 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
namespace mlir {
namespace TFTPU {
@ -30,30 +42,182 @@ namespace {
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
struct TPUExtractHeadTailOutsideCompilation
: public PassWrapper<TPUExtractHeadTailOutsideCompilation, FunctionPass> {
void runOnFunction() override;
};
bool HasOutsideCompilationAttribute(Operation* op) {
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
}
void TPUExtractHeadTailOutsideCompilation::runOnFunction() {
getFunction().walk([&](tf_device::LaunchOp launch) {
Block& launch_block = launch.GetBody();
for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) {
// TODO(b/155115766): Handle outputs that should be inputs to TPU
// LaunchOp.
if (auto attr =
op.getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
op.moveBefore(launch);
} else {
// Returns whether all operands of `op` are from values inside the
// `input_value_set`.
bool OpContainsOperandsFromSet(Operation* op,
const llvm::SetVector<Value>& input_value_set) {
for (auto operand : op->getOperands())
if (input_value_set.count(operand) == 0) return false;
return true;
}
void RecordOutsideCompiledOpsAndUsages(
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
llvm::SetVector<Value>* outside_compiled_op_usages) {
if (HasOutsideCompilationAttribute(op) &&
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
outside_compiled_ops->insert(op);
outside_compiled_op_usages->insert(op->getResults().begin(),
op->getResults().end());
}
}
// Traverses the MLIR graph and returns a set of ops that
// are connected to inputs of TPU computation and outside compiled.
void ExtractOutsideCompiledOpsConnectedToHead(
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
for (auto& usage : input_value.getUses()) {
auto head_operation = usage.getOwner();
RecordOutsideCompiledOpsAndUsages(head_operation,
&parent_outside_compiled_ops_at_head,
values_used_in_host_cluster);
}
// Traverse the graph and find all outside compiled ops connected from
// the `input_value`.
while (!parent_outside_compiled_ops_at_head.empty()) {
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
auto op_results = head_outside_compiled_op->getOpResults();
for (auto op_result : op_results) {
for (auto& use : op_result.getUses()) {
auto connected_op = use.getOwner();
RecordOutsideCompiledOpsAndUsages(connected_op,
&connected_outside_compiled_ops,
values_used_in_host_cluster);
}
}
}
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
parent_outside_compiled_ops_at_head.end());
std::swap(parent_outside_compiled_ops_at_head,
connected_outside_compiled_ops);
}
}
// TODO(hongjunchoi): Also handle ops without inputs that are outside
// compiled.
//
// Returns set of ops that are outside compiled and are directly connected
// to inputs to the TPU computation.
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
tf_device::ClusterOp tpu_cluster) {
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
llvm::SetVector<Value> values_used_in_cluster;
auto& cluster_region = tpu_cluster.body();
getUsedValuesDefinedAbove(cluster_region, cluster_region,
values_used_in_cluster);
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
for (auto input_value : input_value_list)
ExtractOutsideCompiledOpsConnectedToHead(
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
return outside_compiled_at_head_ops;
}
// Returns output values of extracted outside compiled cluster at head that
// are used by the TPU computation.
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
llvm::SmallVector<Value, 8> outputs;
outputs.reserve(head_outside_compiled_ops.size());
for (auto op : head_outside_compiled_ops) {
for (Operation* user : op->getUsers()) {
if (!head_outside_compiled_ops.count(user)) {
outputs.append(op->result_begin(), op->result_end());
break;
}
}
}
return outputs;
}
// Creates new tf_device.launch op with outside compiled ops extracted
// from the head of TPU computation.
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
OpBuilder* builder, tf_device::ClusterOp cluster,
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
if (head_outside_compiled_ops.empty())
return llvm::Optional<tf_device::LaunchOp>();
// Create tf_device.launch op to separate all extracted outside compiled ops
// before the tf_device.cluster.
auto output_values =
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
llvm::SmallVector<Type, 8> output_return_types;
output_return_types.reserve(output_values.size());
for (auto output : output_values)
output_return_types.emplace_back(output.getType());
builder->setInsertionPoint(cluster);
auto host_launch_op = builder->create<tf_device::LaunchOp>(
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
// Replace all usages of outside compiled ops that are used in TPU
// computation with the results of the above created launch op.
for (auto output_and_index : llvm::enumerate(output_values)) {
auto output_index = output_and_index.index();
auto output = output_and_index.value();
for (auto& use : output.getUses()) {
if (!head_outside_compiled_ops.count(use.getOwner()))
use.set(host_launch_op.getResult(output_index));
}
}
// Create terminator op for the newly created launch op.
host_launch_op.body().push_back(new Block());
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
auto terminator = builder->create<tf_device::ReturnOp>(
host_launch_op.getLoc(), output_values);
// Move all outside compile ops from cluster op to launch op.
for (auto outside_compiled_op : head_outside_compiled_ops)
outside_compiled_op->moveBefore(terminator);
return host_launch_op;
}
struct TPUExtractHeadTailOutsideCompilation
: public PassWrapper<TPUExtractHeadTailOutsideCompilation,
OperationPass<ModuleOp>> {
void runOnOperation() override;
};
void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
OpBuilder builder(&getContext());
module.walk([&](tf_device::ClusterOp cluster) {
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
IsolateHeadExtractedOpsToLaunchOp(&builder, cluster,
head_outside_compiled_ops);
// TODO(b/156030523): Update device attribute of newly created host launch
// op as well as enclosing Replicate op (if TPU computation is replicated)
// with host device names.
// TODO(b/155115766): Implement tail outside compiled op extraction.
});
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractHeadTailOutsideCompilationPass() {
return std::make_unique<TPUExtractHeadTailOutsideCompilation>();
}