[MLIR:TF/XLA] TPU dynamic layout pass

A pass that allows the TPU input layout to be determined after JIT compilation.
This is done by adding run-time ops that interpret the compilation result and
copy the input to the device with that layout.

PiperOrigin-RevId: 292970557
Change-Id: If453686abea1618ba6f88016912ea15e0f0b3d4e
This commit is contained in:
Yuanzhong Xu 2020-02-03 11:52:16 -08:00 committed by TensorFlower Gardener
parent 874358298b
commit ab94de29f5
6 changed files with 446 additions and 0 deletions

View File

@ -256,6 +256,7 @@ cc_library(
"transforms/sink_constant.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/tpu_cluster_formation.cc",
"transforms/tpu_dynamic_layout_pass.cc",
"transforms/tpu_dynamic_padding_mapper.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_rewrite_pass.cc",

View File

@ -6386,6 +6386,25 @@ occurred during compilation.
);
}
// Declares the "TPUCopyWithLayout" op, marked NoSideEffect.
// Operands: `input` (a TF tensor) and `layout` (a 64-bit integer tensor).
// Result: `output` (a TF tensor).
def TF_TPUCopyWithLayoutOp : TF_Op<"TPUCopyWithLayout", [NoSideEffect]> {
let summary = "Op that copies host tensor to device with specified layout.";
let description = [{
For internal use only.
}];
let arguments = (ins
TF_Tensor:$input,
I64Tensor:$layout
);
let results = (outs
TF_Tensor:$output
);
// The `T` attribute is derived from operand 0 (`input`) rather than stored.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> {
let summary = "Op that loads and executes a TPU program on a TPU device.";
@ -6437,6 +6456,27 @@ output. For the internal use of the distributed TPU compiler.
TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}
// Declares the "TPUGetLayoutOp" op, marked NoSideEffect.
// Operand: `cache_key` (a string tensor).
// Attributes: `index` (64-bit integer) and `is_output` (boolean), which select
// the input or output whose layout is requested.
// Result: `layout` (a 64-bit integer tensor).
def TF_TPUGetLayoutOp : TF_Op<"TPUGetLayoutOp", [NoSideEffect]> {
let summary = [{
Op that retrieves the layout of an input or output determined by TPUCompile.
}];
let description = [{
For internal use only.
}];
let arguments = (ins
TF_StrTensor:$cache_key,
I64Attr:$index,
BoolAttr:$is_output
);
let results = (outs
I64Tensor:$layout
);
}
def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
let summary = "Connects N inputs to an N-way replicated TPU computation.";

View File

@ -0,0 +1,150 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-dynamic-layout-pass | FileCheck %s --dump-input=fail
// Tests that the pass can transform non-replicated execution.
// CHECK: func @non_replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
// CHECK: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) {device = "/device:TPU:0"}
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
}
// -----
// Tests that the pass does not transform two execute ops sharing the same
// compile op.
// CHECK-LABEL: func @multiple_compile_uses
func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
%4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
%5 = "tf.TPUExecute"(%4#0, %4#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %5 : tensor<i32>
}
// -----
// Tests that the pass does not transform when tf.IteratorGetNext is on TPU.
// CHECK-LABEL: func @on_tpu_iter
func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
}
// -----
// Tests that the pass does not change unsupported input ops.
// CHECK-LABEL: func @unsupported_ops
func @unsupported_ops(%arg0: tensor<3x3x1x32xf32>) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32>
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%arg0, %2, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
}
// -----
// Tests that the pass can transform replicated execution.
// CHECK: func @replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
// CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
// CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"
%3:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"}
// CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"}
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
// CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, [%[[COPY1]], %[[COPY3]]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
%5:2 = tf_device.replicate([%2#0, %3#0] as %r0: tensor<3x3x1x32xf32>, [%2#1, %3#1] as %r1: tensor<3x3x1x32xf32>)
{n = 2 : i32, devices = ["/device:TPU:0", "/device:TPU:1"]} {
// CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
%4 = "tf.TPUExecute"(%r0, %r1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}
return %5#0 : tensor<i32>
}
// -----
// Tests that the pass does not change inputs inside replicate.
// CHECK-LABEL: func @inside_replicated
func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%5:2 = tf_device.replicate([%arg0, %arg1] as %r0: tensor<*x!tf.resource>)
{n = 2 : i32, devices = ["/device:TPU:0", "/device:TPU:1"]} {
%2:2 = "tf.IteratorGetNext"(%r0)
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
%4 = "tf.TPUExecute"(%2#0, %2#1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}
return %5#0 : tensor<i32>
}

View File

@ -55,6 +55,7 @@ void CreateTPUBridge(OpPassManager &pm) {
pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
pm.addPass(CreateTPURewritePass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
pm.addNestedPass<FuncOp>(CreateTPUDynamicLayoutPass());
pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
// TODO(b/147020076): Enable this pass.
// pm.addPass(CreateTPUVariableReformattingPass());

View File

@ -142,6 +142,10 @@ namespace TFTPU {
// `_tpu_replicate` attribute.
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass();
// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass();
// Creates a pass that remaps and assigns padding map from a
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();

View File

@ -0,0 +1,250 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/OperationSupport.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace mlir {
namespace TFTPU {
namespace {
// Name of the attribute that this file reads from, and writes to, ops to get or
// set their device assignment.
constexpr char kDeviceAttr[] = "device";
// A pass that allows TPU input layout to be determined after JIT compilation.
// This is done by adding run-time ops that interpret compilation result and
// copy the input to device with that layout.
//
// Example: original program:
//
// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
// %compile:2 = "tf._TPUCompileMlir"(...)
// %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
//
// Without this pass, later TF graph partitioning passes will insert send/recv
// between %input and %execute and data will be copied to device in a fixed
// layout. With this pass, the program will be transformed into:
//
// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
// %compile:2 = "tf._TPUCompileMlir"(...)
// %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
// %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
// {device = "/TPU:0"}
// %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
// {device = "/TPU:0"}
//
// This way, %compile will determine the layout, which will be respected by
// %copy_to_device. There will not be send/recv ops added by later passes,
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass : public FunctionPass<TPUDynamicLayoutPass> {
// Entry point of the pass: collects tf.TPUExecute ops together with the
// tf._TPUCompileMlir op producing their key, then rewrites the supported inputs
// of each collected pair.
void runOnFunction() override;
};
// Returns true when `op` is an input producer this transform can handle.
// Currently that means a tf.IteratorGetNext op whose `device` attribute is
// present, parses as a full device name, and has device type "CPU".
bool IsSupportedInputOp(Operation* op) {
  // Only iterator outputs are candidates.
  const bool is_iterator_get_next = llvm::isa<TF::IteratorGetNextOp>(op);
  if (!is_iterator_get_next) return false;

  // An op without an explicit device assignment is not considered.
  if (auto device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    const bool parsed_ok = tensorflow::DeviceNameUtils::ParseFullName(
        device_attr.getValue().str(), &parsed_name);
    return parsed_ok && parsed_name.type == "CPU";
  }
  return false;
}
// Creates a tf.TPUGetLayoutOp immediately after `compile`.
//
// The new op takes result #1 of `compile` as its operand, carries
// `index` as its "index" attribute and `false` as its "is_output" attribute,
// and produces a rank-1 tensor of 64-bit integers whose single dimension is
// given as -1. `builder`'s insertion point is moved as a side effect.
TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
                                  OpBuilder* builder) {
  builder->setInsertionPointAfter(compile);

  Type layout_type = RankedTensorType::get({-1}, builder->getIntegerType(64));
  Value compile_key = compile->getResult(1);
  NamedAttribute layout_attrs[] = {
      builder->getNamedAttr("index", builder->getI64IntegerAttr(index)),
      builder->getNamedAttr("is_output", builder->getBoolAttr(false))};

  return builder->create<TF::TPUGetLayoutOp>(
      compile->getLoc(), llvm::ArrayRef<Type>{layout_type},
      llvm::ArrayRef<Value>{compile_key},
      llvm::ArrayRef<NamedAttribute>(layout_attrs));
}
// Creates a tf.TPUCopyWithLayout op that consumes `input` and the layout
// produced by `get_layout`, and yields a value of the same type as `input`.
//
// The new op has to come after both `get_layout` and the op defining `input`.
// `walk_order` records the position of every op of the original IR, so the
// positions of `compile` and of the defining op are compared to decide which of
// the two insertion points is the later one. `input` must be an op result and
// both ops must be present in `walk_order`. The op is created at `execute`'s
// location with no attributes; callers set the device afterwards.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
    TF::TPUExecuteOp execute, Operation* compile, TF::TPUGetLayoutOp get_layout,
    Value input, const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
    OpBuilder* builder) {
  Operation* producer = input.getDefiningOp();
  const int64_t compile_position = walk_order.find(compile)->getSecond();
  const int64_t producer_position = walk_order.find(producer)->getSecond();

  // Pick whichever of the two anchors appears later in the original IR.
  Operation* layout_op = get_layout;
  Operation* insert_after =
      compile_position > producer_position ? layout_op : producer;
  builder->setInsertionPointAfter(insert_after);

  Type copied_type = input.getType();
  Value layout = get_layout.layout();
  return builder->create<TF::TPUCopyWithLayoutOp>(
      execute.getLoc(), llvm::ArrayRef<Type>{copied_type},
      llvm::ArrayRef<Value>{input, layout},
      llvm::ArrayRef<NamedAttribute>{});
}
// Rewrites one non-replicated input of `execute`.
//
// Builds a get-layout op for argument `index` after `compile`, builds a
// copy-with-layout op for `input`, gives the copy the same `device` attribute
// as `execute` (when `execute` has one), and finally makes the copy the
// `index`-th operand of `execute`.
void HandleInput(Value input, int64_t index, TF::TPUExecuteOp execute,
                 Operation* compile,
                 const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
  OpBuilder builder(compile->getContext());
  TF::TPUGetLayoutOp layout_op = BuildGetLayout(compile, index, &builder);
  TF::TPUCopyWithLayoutOp copy_op = BuildCopyWithLayout(
      execute, compile, layout_op, input, walk_order, &builder);

  // Without a device on the execute op the copy is left unplaced as well.
  auto execute_device = execute.getAttrOfType<StringAttr>(kDeviceAttr);
  if (execute_device) copy_op.setAttr(kDeviceAttr, execute_device);

  execute.setOperand(index, copy_op);
}
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleReplicatedInputs(
int64_t index, TF::TPUExecuteOp execute, Operation* compile,
int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
auto inputs = replicate.getOperands()
.drop_front(replicate_arg_index * num_replicas)
.take_front(num_replicas);
for (auto entry : llvm::enumerate(inputs)) {
auto input_op = entry.value().getDefiningOp();
if (!input_op || !IsSupportedInputOp(input_op)) return false;
}
OpBuilder builder(execute.getContext());
auto get_layout = BuildGetLayout(compile, index, &builder);
for (auto entry : llvm::enumerate(inputs)) {
auto copy_with_layout = BuildCopyWithLayout(
execute, compile, get_layout, entry.value(), walk_order, &builder);
copy_with_layout.setAttr(kDeviceAttr,
replicate.devices()->getValue()[entry.index()]);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
copy_with_layout);
}
return true;
}
// Rewrites the inputs of one (execute, compile) pair. The compile op is
// expected to have no other users of its program key.
//
// Every argument of `execute` is examined; those that get rewritten have their
// index recorded, and afterwards the "metadata" attribute of `compile` is
// re-serialized with `unrestricted_layout` set on each recorded argument.
void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
                   const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
  auto enclosing_replicate =
      execute.getParentOfType<tf_device::ReplicateOp>();
  llvm::SmallVector<int64_t, 8> rewritten_indices;

  for (auto arg : llvm::enumerate(execute.args())) {
    if (auto block_arg = arg.value().dyn_cast<BlockArgument>()) {
      // A block argument is only of interest when it belongs to the replicate
      // op enclosing the execute, i.e. it is a replicated input whose
      // defining ops live outside the replicate.
      if (enclosing_replicate != block_arg.getParentRegion()->getParentOp()) {
        continue;
      }
      if (!HandleReplicatedInputs(arg.index(), execute, compile,
                                  block_arg.getArgNumber(),
                                  enclosing_replicate, walk_order)) {
        continue;
      }
    } else {
      // An op result is only of interest when 1) there is no replication, or
      // 2) the producer sits outside the replicate op that encloses the
      // execute. (A producer inside the replicate is probably not on the
      // host.)
      auto producer = arg.value().getDefiningOp();
      if (enclosing_replicate &&
          enclosing_replicate.body().isAncestor(producer->getParentRegion())) {
        continue;
      }
      if (!IsSupportedInputOp(producer)) continue;
      HandleInput(arg.value(), arg.index(), execute, compile, walk_order);
    }
    rewritten_indices.push_back(arg.index());
  }

  // Nothing was rewritten, so the metadata stays as it is.
  if (rewritten_indices.empty()) return;

  // Mark every rewritten argument as having an unrestricted layout in the
  // compilation metadata.
  auto metadata_attr = compile->getAttrOfType<StringAttr>("metadata");
  assert(metadata_attr && "Missing compilation metadata");
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  // NOTE(review): the result of ParseFromString is not checked here.
  metadata.ParseFromString(std::string(metadata_attr.getValue()));
  for (int64_t arg_index : rewritten_indices) {
    metadata.mutable_args(arg_index)->set_unrestricted_layout(true);
  }
  const std::string serialized = metadata.SerializeAsString();
  compile->setAttr("metadata", OpBuilder(compile).getStringAttr(serialized));
}
void TPUDynamicLayoutPass::runOnFunction() {
llvm::SmallVector<std::pair<TF::TPUExecuteOp, Operation*>, 4>
executes_and_compiles;
llvm::SmallDenseMap<Operation*, int64_t> walk_order;
int64_t next_walk_order = 0;
getFunction().walk([&](Operation* op) {
walk_order[op] = next_walk_order++;
// Detect tf._TPUCompileMlir -> tf.TPUExecute
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(op);
if (!execute) return;
auto compile = execute.key().getDefiningOp();
if (!compile || compile->getName().getStringRef() != "tf._TPUCompileMlir" ||
!compile->getResult(1).hasOneUse()) {
return;
}
executes_and_compiles.emplace_back(execute, compile);
});
for (auto execute_and_compile : executes_and_compiles) {
HandleExecute(execute_and_compile.first, execute_and_compile.second,
walk_order);
}
}
} // namespace
// Factory for the pass: returns a newly created TPUDynamicLayoutPass owned by
// the caller.
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass() {
return std::make_unique<TPUDynamicLayoutPass>();
}
// Static registration of the pass under the name "tf-tpu-dynamic-layout-pass",
// together with its description string.
static PassRegistration<TPUDynamicLayoutPass> pass(
"tf-tpu-dynamic-layout-pass",
"Adds ops that allow TPU program inputs to have layouts determined at JIT "
"compile time.");
} // namespace TFTPU
} // namespace mlir