From 1de39b575611a252531d0238eefb8a394fa96286 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 12 May 2020 11:57:52 -0700 Subject: [PATCH] Implement outside compilation head extraction. PiperOrigin-RevId: 311172756 Change-Id: Id3dbcbd1582a01ec94424dbb8b08bb475466568c --- ...extract_head_tail_outside_compilation.mlir | 83 ++++++-- .../mlir/tensorflow/transforms/passes.h | 2 +- ...u_extract_head_tail_outside_compilation.cc | 194 ++++++++++++++++-- 3 files changed, 247 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index 77ca08c089a..eb67bdcc914 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -1,13 +1,17 @@ // RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-head-tail-outside-compilation | FileCheck %s --dump-input-on-failure -// Tests extraction of a single outside compiled cluster with no input or output dependecies. +// Tests extraction of a outside compiled ops at head of TPU computation. -// CHECK-LABEL: func @nodep_single_head_outside_compilation -func @nodep_single_head_outside_compilation() -> () { - // CHECK: "tf.A" - // CHECK-NEXT: "tf_device.launch" - "tf_device.launch"() ( { - "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () +func @single_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: tf_device.launch + // CHECK: "tf.A" + // CHECK-NEXT: tf_device.return + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> () "tf.B"() : () -> () "tf.C"() : () -> () tf_device.return @@ -15,15 +19,62 @@ func @nodep_single_head_outside_compilation() -> () { return } -// CHECK-LABEL: func @nodep_multiple_head_outside_compilation -func @nodep_multiple_head_outside_compilation() -> () { - // CHECK: "tf.A" - // CHECK-NEXT: "tf.B" - // CHECK-NEXT: "tf_device.launch" - "tf_device.launch"() ( { - "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () - "tf.C"() : () -> () +// CHECK-LABEL: func @multiple_head_outside_compilation +func @multiple_head_outside_compilation(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) + // CHECK: "tf.C" + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.D"(%[[LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + "tf.C"(%1, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.D"(%1) : (tensor) -> () + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} + +// CHECK-LABEL: func @test_do_not_outside_compiled_ops_in_middle +func @test_do_not_outside_compiled_ops_in_middle(%arg0 : tensor) -> () { + // CHECK-NOT: tf_device.launch + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {} : (tensor) -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + "tf.C"(%1) : (tensor) -> () + tf_device.return + }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () + return +} + +// CHECK-LABEL: func @test_ops_with_tpu_operands_not_extracted +func @test_ops_with_tpu_operands_not_extracted(%arg0 : tensor) -> () { + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK: %[[A_OUT:.*]] = "tf.A" + // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]]) + // CHECK-NEXT: tf_device.return %[[D_OUT]] + // + // CHECK: "tf_device.cluster" + // CHECK: "tf.B" + // CHECK: "tf.C" + // CHECK: "tf.E" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %0 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %1 = "tf.B"() {} : () -> (tensor) + %2 = "tf.C"(%arg0, %1) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + %3 = "tf.D"(%0) {_xla_outside_compilation = "cluster1"}: (tensor) -> (tensor) + %4 = "tf.E"(%3) {} : (tensor) -> (tensor) tf_device.return }) {device = "tpu0", launch_attr = "launch_attr"} : () -> () return diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index c1d99c2dee3..0b1ff2beebb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -258,7 +258,7 @@ std::unique_ptr> CreateTPUVariableReformattingPass(); // Creates a pass that extracts outside compilation (CPU ops inside TPU cluster) // at head/tail of TPU cluster to run before/after TPU computation. -std::unique_ptr> +std::unique_ptr> CreateTPUExtractHeadTailOutsideCompilationPass(); // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 141feeb6b24..b9e214470cd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -14,11 +14,23 @@ limitations under the License. ==============================================================================*/ #include +#include +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" namespace mlir { namespace TFTPU { @@ -30,30 +42,182 @@ namespace { constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; -struct TPUExtractHeadTailOutsideCompilation - : public PassWrapper { - void runOnFunction() override; -}; +bool HasOutsideCompilationAttribute(Operation* op) { + return op->getAttrOfType(kXlaOutsideCompilationAttr) != nullptr; +} -void TPUExtractHeadTailOutsideCompilation::runOnFunction() { - getFunction().walk([&](tf_device::LaunchOp launch) { - Block& launch_block = launch.GetBody(); - for (auto& op : llvm::make_early_inc_range(launch_block.getOperations())) { - // TODO(b/155115766): Handle outputs that should be inputs to TPU - // LaunchOp. - if (auto attr = - op.getAttrOfType(kXlaOutsideCompilationAttr)) { - op.moveBefore(launch); - } else { +// Returns whether all operands of `op` are from values inside the +// `input_value_set`. +bool OpContainsOperandsFromSet(Operation* op, + const llvm::SetVector& input_value_set) { + for (auto operand : op->getOperands()) + if (input_value_set.count(operand) == 0) return false; + + return true; +} + +void RecordOutsideCompiledOpsAndUsages( + Operation* op, llvm::SmallSetVector* outside_compiled_ops, + llvm::SetVector* outside_compiled_op_usages) { + if (HasOutsideCompilationAttribute(op) && + OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) { + outside_compiled_ops->insert(op); + outside_compiled_op_usages->insert(op->getResults().begin(), + op->getResults().end()); + } +} + +// Traverses the MLIR graph and returns a set of ops that +// are connected to inputs of TPU computation and outside compiled. +void ExtractOutsideCompiledOpsConnectedToHead( + Value input_value, llvm::SetVector* values_used_in_host_cluster, + llvm::SmallSetVector* outside_compiled_ops) { + llvm::SmallSetVector parent_outside_compiled_ops_at_head; + for (auto& usage : input_value.getUses()) { + auto head_operation = usage.getOwner(); + RecordOutsideCompiledOpsAndUsages(head_operation, + &parent_outside_compiled_ops_at_head, + values_used_in_host_cluster); + } + + // Traverse the graph and find all outside compiled ops connected from + // the `input_value`. + while (!parent_outside_compiled_ops_at_head.empty()) { + llvm::SmallSetVector connected_outside_compiled_ops; + for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) { + auto op_results = head_outside_compiled_op->getOpResults(); + for (auto op_result : op_results) { + for (auto& use : op_result.getUses()) { + auto connected_op = use.getOwner(); + RecordOutsideCompiledOpsAndUsages(connected_op, + &connected_outside_compiled_ops, + values_used_in_host_cluster); + } + } + } + + outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(), + parent_outside_compiled_ops_at_head.end()); + std::swap(parent_outside_compiled_ops_at_head, + connected_outside_compiled_ops); + } +} + +// TODO(hongjunchoi): Also handle ops without inputs that are outside +// compiled. +// +// Returns set of ops that are outside compiled and are directly connected +// to inputs to the TPU computation. +llvm::SmallSetVector IdentifyOutsideCompiledOpsAtHead( + tf_device::ClusterOp tpu_cluster) { + llvm::SmallSetVector outside_compiled_at_head_ops; + llvm::SetVector values_used_in_cluster; + auto& cluster_region = tpu_cluster.body(); + getUsedValuesDefinedAbove(cluster_region, cluster_region, + values_used_in_cluster); + + auto input_value_list = llvm::to_vector<8>(values_used_in_cluster); + for (auto input_value : input_value_list) + ExtractOutsideCompiledOpsConnectedToHead( + input_value, &values_used_in_cluster, &outside_compiled_at_head_ops); + return outside_compiled_at_head_ops; +} + +// Returns output values of extracted outside compiled cluster at head that +// are used by the TPU computation. +llvm::SmallVector GetHeadExtractedClusterOutputs( + const llvm::SmallSetVector& head_outside_compiled_ops) { + llvm::SmallVector outputs; + outputs.reserve(head_outside_compiled_ops.size()); + + for (auto op : head_outside_compiled_ops) { + for (Operation* user : op->getUsers()) { + if (!head_outside_compiled_ops.count(user)) { + outputs.append(op->result_begin(), op->result_end()); break; } } + } + + return outputs; +} + +// Creates new tf_device.launch op with outside compiled ops extracted +// from the head of TPU computation. +llvm::Optional IsolateHeadExtractedOpsToLaunchOp( + OpBuilder* builder, tf_device::ClusterOp cluster, + const llvm::SmallSetVector& head_outside_compiled_ops) { + if (head_outside_compiled_ops.empty()) + return llvm::Optional(); + + // Create tf_device.launch op to separate all extracted outside compiled ops + // before the tf_device.cluster. + auto output_values = + GetHeadExtractedClusterOutputs(head_outside_compiled_ops); + + llvm::SmallVector output_return_types; + output_return_types.reserve(output_values.size()); + for (auto output : output_values) + output_return_types.emplace_back(output.getType()); + + builder->setInsertionPoint(cluster); + auto host_launch_op = builder->create( + cluster.getLoc(), builder->getStringAttr(""), output_return_types); + + // Replace all usages of outside compiled ops that are used in TPU + // computation with the results of the above created launch op. + for (auto output_and_index : llvm::enumerate(output_values)) { + auto output_index = output_and_index.index(); + auto output = output_and_index.value(); + for (auto& use : output.getUses()) { + if (!head_outside_compiled_ops.count(use.getOwner())) + use.set(host_launch_op.getResult(output_index)); + } + } + + // Create terminator op for the newly created launch op. + host_launch_op.body().push_back(new Block()); + builder->setInsertionPointToEnd(&host_launch_op.GetBody()); + auto terminator = builder->create( + host_launch_op.getLoc(), output_values); + + // Move all outside compile ops from cluster op to launch op. + for (auto outside_compiled_op : head_outside_compiled_ops) + outside_compiled_op->moveBefore(terminator); + + return host_launch_op; +} + +struct TPUExtractHeadTailOutsideCompilation + : public PassWrapper> { + void runOnOperation() override; +}; + +void TPUExtractHeadTailOutsideCompilation::runOnOperation() { + // Get runtime devices information from the closest parent module. + auto module = getOperation(); + mlir::TF::RuntimeDevices devices; + if (failed(tensorflow::GetDevicesFromOp(module, &devices))) + return signalPassFailure(); + + OpBuilder builder(&getContext()); + module.walk([&](tf_device::ClusterOp cluster) { + auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster); + IsolateHeadExtractedOpsToLaunchOp(&builder, cluster, + head_outside_compiled_ops); + + // TODO(b/156030523): Update device attribute of newly created host launch + // op as well as enclosing Replicate op (if TPU computation is replicated) + // with host device names. + + // TODO(b/155115766): Implement tail outside compiled op extraction. }); } } // anonymous namespace -std::unique_ptr> +std::unique_ptr> CreateTPUExtractHeadTailOutsideCompilationPass() { return std::make_unique(); }