[MLIR:TF/XLA] TPU dynamic layout pass
A pass that allows the TPU input layout to be determined after JIT compilation. This is done by adding run-time ops that interpret the compilation result and copy the input to the device with that layout.

PiperOrigin-RevId: 292970557
Change-Id: If453686abea1618ba6f88016912ea15e0f0b3d4e
Parent: 874358298b
Commit: ab94de29f5
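For reference, a rough before/after sketch of the rewrite, condensed from the pass comment in the new tpu_dynamic_layout_pass.cc below (operands and attributes elided):

    // Before:
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
    %compile:2 = "tf._TPUCompileMlir"(...)
    %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}

    // After:
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
    %compile:2 = "tf._TPUCompileMlir"(...)
    %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
    %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout) {device = "/TPU:0"}
    %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1) {device = "/TPU:0"}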
@@ -256,6 +256,7 @@ cc_library(
        "transforms/sink_constant.cc",
        "transforms/test_side_effect_analysis.cc",
        "transforms/tpu_cluster_formation.cc",
        "transforms/tpu_dynamic_layout_pass.cc",
        "transforms/tpu_dynamic_padding_mapper.cc",
        "transforms/tpu_merge_variables_with_execute.cc",
        "transforms/tpu_rewrite_pass.cc",
@@ -6386,6 +6386,25 @@ occurred during compilation.
  );
}

def TF_TPUCopyWithLayoutOp : TF_Op<"TPUCopyWithLayout", [NoSideEffect]> {
  let summary = "Op that copies host tensor to device with specified layout.";

  let description = [{
For internal use only.
  }];

  let arguments = (ins
    TF_Tensor:$input,
    I64Tensor:$layout
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_TPUExecuteOp : TF_Op<"TPUExecute", []> {
  let summary = "Op that loads and executes a TPU program on a TPU device.";
@@ -6437,6 +6456,27 @@ output. For the internal use of the distributed TPU compiler.
  TF_DerivedResultTypeListAttr Tresults = TF_DerivedResultTypeListAttr<0>;
}

def TF_TPUGetLayoutOp : TF_Op<"TPUGetLayoutOp", [NoSideEffect]> {
  let summary = [{
Op that retrieves the layout of an input or output determined by TPUCompile.
  }];

  let description = [{
For internal use only.
  }];

  let arguments = (ins
    TF_StrTensor:$cache_key,

    I64Attr:$index,
    BoolAttr:$is_output
  );

  let results = (outs
    I64Tensor:$layout
  );
}

def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> {
  let summary = "Connects N inputs to an N-way replicated TPU computation.";
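Taken together, the two new ops sit between the compile and execute ops. A minimal usage sketch, condensed from the test cases in the new tpu_dynamic_layout_pass.mlir below (the tensor<?xi64> layout type matches what the pass builds; %compile and %input are assumed to come from tf._TPUCompileMlir and tf.IteratorGetNext, respectively):

    %layout = "tf.TPUGetLayoutOp"(%compile#1) {index = 0 : i64, is_output = false}
        : (tensor<!tf.string>) -> tensor<?xi64>
    %copy = "tf.TPUCopyWithLayout"(%input, %layout) {device = "/device:TPU:0"}
        : (tensor<3x3x1x32xf32>, tensor<?xi64>) -> tensor<3x3x1x32xf32>
    %result = "tf.TPUExecute"(%copy, %compile#1) {device = "/device:TPU:0"}
        : (tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>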
@@ -0,0 +1,150 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-dynamic-layout-pass | FileCheck %s --dump-input=fail

// Tests that the pass can transform non-replicated execution.

// CHECK: func @non_replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
  // CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
  // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
  %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
  // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  // CHECK: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) {device = "/device:TPU:0"}
  %3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
  return %3 : tensor<i32>
}

// -----

// Tests that the pass does not transform two execute ops sharing the same
// compile op.

// CHECK-LABEL: func @multiple_compile_uses
func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-NOT: "tf.TPUGetLayoutOp"
  // CHECK-NOT: "tf.TPUCopyWithLayout"
  %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  %3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
  %4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  %5 = "tf.TPUExecute"(%4#0, %4#1, %1#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
  return %5 : tensor<i32>
}

// -----

// Tests that the pass does not transform when tf.IteratorGetNext is on TPU.

// CHECK-LABEL: func @on_tpu_iter
func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-NOT: "tf.TPUGetLayoutOp"
  // CHECK-NOT: "tf.TPUCopyWithLayout"
  %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"}
    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  %3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
  return %3 : tensor<i32>
}

// -----

// Tests that the pass does not change unsupported input ops.

// CHECK-LABEL: func @unsupported_ops
func @unsupported_ops(%arg0: tensor<3x3x1x32xf32>) -> tensor<i32> {
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-NOT: "tf.TPUGetLayoutOp"
  // CHECK-NOT: "tf.TPUCopyWithLayout"
  %2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32>
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  %3 = "tf.TPUExecute"(%arg0, %2, %1#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
  return %3 : tensor<i32>
}

// -----

// Tests that the pass can transform replicated execution.

// CHECK: func @replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
  // CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
  %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  // CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
  // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
  // CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
  // CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
  // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext"
  %3:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
    : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
  // CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"}
  // CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"}
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  // CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, [%[[COPY1]], %[[COPY3]]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
  %5:2 = tf_device.replicate([%2#0, %3#0] as %r0: tensor<3x3x1x32xf32>, [%2#1, %3#1] as %r1: tensor<3x3x1x32xf32>)
      {n = 2 : i32, devices = ["/device:TPU:0", "/device:TPU:1"]} {
    // CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
    %4 = "tf.TPUExecute"(%r0, %r1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
    tf_device.return %4 : tensor<i32>
  }
  return %5#0 : tensor<i32>
}

// -----

// Tests that the pass does not change inputs inside replicate.

// CHECK-LABEL: func @inside_replicated
func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<i32> {
  %1:2 = "tf._TPUCompileMlir"() {
    NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
    // The metadata encodes two parameters and two return values.
    metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
    mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
  // CHECK-NOT: "tf.TPUGetLayoutOp"
  // CHECK-NOT: "tf.TPUCopyWithLayout"
  "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
  %5:2 = tf_device.replicate([%arg0, %arg1] as %r0: tensor<*x!tf.resource>)
      {n = 2 : i32, devices = ["/device:TPU:0", "/device:TPU:1"]} {
    %2:2 = "tf.IteratorGetNext"(%r0)
      : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
    %4 = "tf.TPUExecute"(%2#0, %2#1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
    tf_device.return %4 : tensor<i32>
  }
  return %5#0 : tensor<i32>
}
@@ -55,6 +55,7 @@ void CreateTPUBridge(OpPassManager &pm) {
  pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass());
  pm.addPass(CreateTPURewritePass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
  pm.addNestedPass<FuncOp>(CreateTPUDynamicLayoutPass());
  pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
  // TODO(b/147020076): Enable this pass.
  // pm.addPass(CreateTPUVariableReformattingPass());
@@ -142,6 +142,10 @@ namespace TFTPU {
// `_tpu_replicate` attribute.
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUClusterFormationPass();

// Creates a pass that allows TPU program inputs to have layouts determined at
// run time.
std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass();

// Creates a pass that remaps and assigns padding map from a
// `tf_device.launch_func` `padding_map` attribute to its encapsulated function.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
@@ -0,0 +1,250 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // TF:llvm-project
#include "mlir/IR/Builders.h"  // TF:llvm-project
#include "mlir/IR/Function.h"  // TF:llvm-project
#include "mlir/IR/Location.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/IR/Operation.h"  // TF:llvm-project
#include "mlir/IR/OperationSupport.h"  // TF:llvm-project
#include "mlir/IR/StandardTypes.h"  // TF:llvm-project
#include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
#include "mlir/IR/Types.h"  // TF:llvm-project
#include "mlir/IR/Value.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"  // TF:llvm-project
#include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";

// A pass that allows TPU input layout to be determined after JIT compilation.
// This is done by adding run-time ops that interpret the compilation result
// and copy the input to the device with that layout.
//
// Example: original program:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
//
// Without this pass, later TF graph partitioning passes will insert send/recv
// between %input and %execute, and data will be copied to the device in a
// fixed layout. With this pass, the program will be transformed into:
//
//   %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
//   %compile:2 = "tf._TPUCompileMlir"(...)
//   %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
//   %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
//     {device = "/TPU:0"}
//   %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
//     {device = "/TPU:0"}
//
// This way, %compile will determine the layout, which will be respected by
// %copy_to_device. There will not be send/recv ops added by later passes,
// because tf.TPUCopyWithLayout accepts a host input and produces a device
// output.
struct TPUDynamicLayoutPass : public FunctionPass<TPUDynamicLayoutPass> {
  void runOnFunction() override;
};

// Checks if the input producer op is supported in this transform. Right now,
// we only check if it is a host tf.IteratorGetNext.
bool IsSupportedInputOp(Operation* op) {
  if (!llvm::isa<TF::IteratorGetNextOp>(op)) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a TPUGetLayoutOp with the given compile op and input index.
TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
                                  OpBuilder* builder) {
  builder->setInsertionPointAfter(compile);
  return builder->create<TF::TPUGetLayoutOp>(
      compile->getLoc(),
      llvm::ArrayRef<Type>{
          RankedTensorType::get({-1}, builder->getIntegerType(64))},
      llvm::ArrayRef<Value>{compile->getResult(1)},
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr("index", builder->getI64IntegerAttr(index)),
          builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
}

// Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
// walk_order for ops in the original IR is needed because the new op must be
// inserted after both get_layout and input, so we use the walk order to find
// which of the two comes later.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
    TF::TPUExecuteOp execute, Operation* compile, TF::TPUGetLayoutOp get_layout,
    Value input, const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
    OpBuilder* builder) {
  auto input_op = input.getDefiningOp();
  int64_t compile_walk_order = walk_order.find(compile)->getSecond();
  int64_t input_walk_order = walk_order.find(input_op)->getSecond();
  if (compile_walk_order > input_walk_order) {
    builder->setInsertionPointAfter(get_layout);
  } else {
    builder->setInsertionPointAfter(input_op);
  }
  return builder->create<TF::TPUCopyWithLayoutOp>(
      execute.getLoc(), llvm::ArrayRef<Type>{input.getType()},
      llvm::ArrayRef<Value>{input, get_layout.layout()},
      llvm::ArrayRef<NamedAttribute>{});
}

// Performs transformation for a non-replicated input.
void HandleInput(Value input, int64_t index, TF::TPUExecuteOp execute,
                 Operation* compile,
                 const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
  OpBuilder builder(compile->getContext());
  auto get_layout = BuildGetLayout(compile, index, &builder);
  auto copy_with_layout = BuildCopyWithLayout(execute, compile, get_layout,
                                              input, walk_order, &builder);
  if (auto device = execute.getAttrOfType<StringAttr>(kDeviceAttr)) {
    copy_with_layout.setAttr(kDeviceAttr, device);
  }
  execute.setOperand(index, copy_with_layout);
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (and the transform happened).
bool HandleReplicatedInputs(
    int64_t index, TF::TPUExecuteOp execute, Operation* compile,
    int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
    const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;
  int64_t num_replicas = replicate.n().getZExtValue();
  auto inputs = replicate.getOperands()
                    .drop_front(replicate_arg_index * num_replicas)
                    .take_front(num_replicas);
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op)) return false;
  }
  OpBuilder builder(execute.getContext());
  auto get_layout = BuildGetLayout(compile, index, &builder);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout = BuildCopyWithLayout(
        execute, compile, get_layout, entry.value(), walk_order, &builder);
    copy_with_layout.setAttr(kDeviceAttr,
                             replicate.devices()->getValue()[entry.index()]);
    replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
                         copy_with_layout);
  }
  return true;
}

// Performs transformation on a pair of execute and compile ops. The compile
// should not have other uses.
void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
                   const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
  auto maybe_replicate = execute.getParentOfType<tf_device::ReplicateOp>();
  llvm::SmallVector<int64_t, 8> unrestricted_input_indices;
  for (auto input : llvm::enumerate(execute.args())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
          !HandleReplicatedInputs(input.index(), execute, compile,
                                  block_arg.getArgNumber(), maybe_replicate,
                                  walk_order)) {
        continue;
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (Because if the op is inside replicate, it is probably
      // not on the host.)
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedInputOp(input_op)) continue;
      HandleInput(input.value(), input.index(), execute, compile, walk_order);
    }
    unrestricted_input_indices.push_back(input.index());
  }
  if (unrestricted_input_indices.empty()) return;

  // Update the compilation metadata if we changed anything.
  auto metadata_attr = compile->getAttrOfType<StringAttr>("metadata");
  assert(metadata_attr && "Missing compilation metadata");
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(std::string(metadata_attr.getValue()));
  for (int64_t input_index : unrestricted_input_indices) {
    metadata.mutable_args(input_index)->set_unrestricted_layout(true);
  }
  compile->setAttr("metadata", OpBuilder(compile).getStringAttr(
                                   metadata.SerializeAsString()));
}

void TPUDynamicLayoutPass::runOnFunction() {
  llvm::SmallVector<std::pair<TF::TPUExecuteOp, Operation*>, 4>
      executes_and_compiles;
  llvm::SmallDenseMap<Operation*, int64_t> walk_order;
  int64_t next_walk_order = 0;
  getFunction().walk([&](Operation* op) {
    walk_order[op] = next_walk_order++;
    // Detect tf._TPUCompileMlir -> tf.TPUExecute
    auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(op);
    if (!execute) return;
    auto compile = execute.key().getDefiningOp();
    if (!compile || compile->getName().getStringRef() != "tf._TPUCompileMlir" ||
        !compile->getResult(1).hasOneUse()) {
      return;
    }
    executes_and_compiles.emplace_back(execute, compile);
  });
  for (auto execute_and_compile : executes_and_compiles) {
    HandleExecute(execute_and_compile.first, execute_and_compile.second,
                  walk_order);
  }
}

}  // namespace

std::unique_ptr<OpPassBase<FuncOp>> CreateTPUDynamicLayoutPass() {
  return std::make_unique<TPUDynamicLayoutPass>();
}

static PassRegistration<TPUDynamicLayoutPass> pass(
    "tf-tpu-dynamic-layout-pass",
    "Adds ops that allow TPU program inputs to have layouts determined at JIT "
    "compile time.");

}  // namespace TFTPU
}  // namespace mlir