[MLIR:TF] TPU space to depth pass.

Add a pass that performs a space to depth transformation at compile time for convolutions that incur excessive TPU padding.
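
For illustration, using the shapes from the test added in this change, the pass inserts on the host

    %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, data_format = "NHWC"}
        : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>

and rewrites the device-side stride-2 Conv2D with a 7x7x3x64 filter into a stride-1 Conv2D over a padded and reshaped 4x4x12x64 filter, producing the same 2x112x112x64 output.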

PiperOrigin-RevId: 314254011
Change-Id: Id4e9e46a954e05e17a38023bf34615773705aeba
A. Unique TensorFlower 2020-06-01 19:54:35 -07:00 committed by TensorFlower Gardener
parent 918731364a
commit df7fd4acda
3 changed files with 791 additions and 0 deletions


@@ -452,6 +452,7 @@ cc_library(
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_space_to_depth_pass.cc",
"transforms/tpu_variable_runtime_reformatting.cc",
"translate/breakup-islands.cc",
"translate/control_to_executor_dialect.cc",


@@ -0,0 +1,87 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-space-to-depth-pass | FileCheck %s --dump-input=fail
// Tests for space to depth host and device transform.
module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 390 : i32}} {
func @main(%arg0: tensor<!tf.resource> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf.variant> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<!tf.resource> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf.variant> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<!tf.resource<tensor<7x7x3x64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg5: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg6: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor<!tf.resource<tensor<i64>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) attributes {tf.entry_function = {control_outputs = "while", inputs = "iterator,iterator_1,iterator_2,iterator_3,while_input_6,while_input_7,while_input_8,while_input_9", outputs = ""}} {
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>)
return
}
// CHECK-LABEL: func @while_body_2710
func @while_body_2710(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<!tf.resource> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<!tf.resource<tensor<7x7x3x64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor<!tf.resource<tensor<i64>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>) attributes {tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[INPUT:.*]] = "tf.IteratorGetNext"
%1 = "tf.IteratorGetNext"(%arg5) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<!tf.resource>) -> tensor<2x224x224x3xf32>
// CHECK-DAG: %[[SPACETODEPTH0:.*]] = "tf.SpaceToDepth"([[INPUT:.*]]) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%2 = "tf.AddV2"(%arg2, %arg3) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%3 = "tf.ReadVariableOp"(%arg6) : (tensor<!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
%4 = "tf.ReadVariableOp"(%arg8) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%5 = "tf.ReadVariableOp"(%arg7) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%6 = "tf.ReadVariableOp"(%arg9) : (tensor<!tf.resource<tensor<i64>>>) -> tensor<i64>
%7:2 = "tf_device.cluster_func"(%1, %3, %5, %6) {_tpu_replicate = "while/cluster_while_body_271", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<7x7x3x64xf32>, tensor<f32>, tensor<i64>) -> (tensor<7x7x3x64xf32>, tensor<i64>)
"tf.AssignVariableOp"(%arg6, %7#0) : (tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<7x7x3x64xf32>) -> ()
"tf.AssignVariableOp"(%arg9, %7#1) : (tensor<!tf.resource<tensor<i64>>>, tensor<i64>) -> ()
%8 = "tf.Identity"(%arg1) {device = ""} : (tensor<i32>) -> tensor<i32>
%9 = "tf.Identity"(%2) {device = ""} : (tensor<i32>) -> tensor<i32>
%10 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%11 = "tf.Identity"(%10) {device = ""} : (tensor<i32>) -> tensor<i32>
return %11, %8, %9, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.resource>, tensor<!tf.resource<tensor<7x7x3x64xf32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<i64>>>
}
func @while_cond_2700(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<!tf.resource> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<!tf.resource<tensor<7x7x3x64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor<!tf.resource<tensor<f32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor<!tf.resource<tensor<i64>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.GreaterEqual"(%arg3, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = "tf.Less"(%arg3, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = "tf.Greater"(%arg2, %arg4) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%4 = "tf.LogicalAnd"(%2, %3) {device = ""} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%5 = "tf.Less"(%arg2, %arg4) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = "tf.LogicalAnd"(%1, %5) {device = ""} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%7 = "tf.LogicalOr"(%6, %4) {device = ""} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%8 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%9 = "tf.LogicalAnd"(%8, %7) {device = ""} : (tensor<i1>, tensor<i1>) -> tensor<i1>
%10 = "tf.Identity"(%9) {device = ""} : (tensor<i1>) -> tensor<i1>
return %10 : tensor<i1>
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
%3 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
%4 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%5 = "tf.Pad"(%arg0, %3) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32>
// CHECK: "tf.Conv2D"
// CHECK-SAME: strides = [1, 1, 1, 1]
// CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
%6 = "tf.Conv2D"(%5, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
// CHECK: %[[BACKPROP:.*]] = "tf.Conv2DBackpropFilter"
// CHECK-SAME: strides = [1, 1, 1, 1]
// CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<4x4x12x64xf32>
%7 = "tf.Conv2DBackpropFilter"(%5, %2, %6) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<7x7x3x64xf32>
// CHECK: %[[CONST0:.*]] = "tf.Const"() {value = dense<
// CHECK-SAME: [4, 4, 2, 2, 3, 64]
// CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"(%[[BACKPROP:.*]], %[[CONST0:.*]]) : (tensor<4x4x12x64xf32>, tensor<6xi64>) -> tensor<4x4x2x2x3x64xf32>
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<
// CHECK-SAME: [0, 2, 1, 3, 4, 5]
// CHECK: %[[TRANSPOSE0:.*]] = "tf.Transpose"(%[[RESHAPE0:.*]], %[[CONST1:.*]]) : (tensor<4x4x2x2x3x64xf32>, tensor<6xi32>) -> tensor<4x2x4x2x3x64xf32>
// CHECK: %[[CONST2:.*]] = "tf.Const"() {value = dense<
// CHECK-SAME: [8, 8, 3, 64]
// CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"(%[[TRANSPOSE1:.*]], %[[CONST2:.*]]) : (tensor<4x2x4x2x3x64xf32>, tensor<4xi64>) -> tensor<8x8x3x64xf32>
// CHECK: %[[CONST3:.*]] = "tf.Const"() {value = dense<
// CHECK-SAME: [7, 7, 3, 64]
// CHECK: %[[CONST4:.*]] = "tf.Const"() {value = dense<
// CHECK-SAME: 0
// CHECK: %[[SLICE0:.*]] = "tf.Slice"(%[[RESHAPE1:.*]], %[[CONST4:.*]], %[[CONST3:.*]]) : (tensor<8x8x3x64xf32>, tensor<4xi64>, tensor<4xi32>) -> tensor<7x7x3x64xf32>
%8 = "tf.CrossReplicaSum"(%7, %1) : (tensor<7x7x3x64xf32>, tensor<1x1xi32>) -> tensor<7x7x3x64xf32>
%9 = "tf.Mul"(%arg2, %8) : (tensor<f32>, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32>
%10 = "tf.Sub"(%arg1, %9) : (tensor<7x7x3x64xf32>, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32>
%11 = "tf.AddV2"(%arg3, %0) : (tensor<i64>, tensor<i64>) -> tensor<i64>
return %10, %11 : tensor<7x7x3x64xf32>, tensor<i64>
}
}
// -----


@@ -0,0 +1,703 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <iostream>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;
// A pass that applies an automatic space to depth transform to the first or
// frontier convolutions that consume host inputs on TPU.
// This is done by adding a space to depth transform op after the host input
// and applying the space to depth transform to the first convolution and its
// backprop filter on TPU.
//
// Example: original program:
//
// module {
// func @while_body {
// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}:
// -> tensor<2x224x224x3xf32>
// %device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...)
// return ...
// }
// func @_func(%input: tensor<2x224x224x3xf32>,
// %filter: tensor<7x7x3x64xf32>) {
// %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}:
// (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) ->
// tensor<2x112x112x64xf32>
// }
// }
//
// With this pass, the program will be transformed into:
// module {
// func @while_body {
// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
// -> tensor<2x224x224x3xf32>
// %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}:
// (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
// %device_launch = "tf_device.cluster_func"(%space_to_depth,...)
// {func = @_func,...)
// return ...
// }
// func @_func(%input: tensor<2x112x112x12xf32>,
// %filter: tensor<7x7x3x64xf32>) {
// %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter):
// tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
// %conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}:
// (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) ->
// tensor<2x112x112x64xf32>
// }
// }
//
// This way, the first convolution, whose input has a feature dimension of 3,
// is transformed into one with a feature dimension of 12, which performs
// better on TPU.
//
// TODO(wangtao): add a pass to check whether the space to depth transform is
// profitable, and invoke the transform only when needed.
struct TPUSpaceToDepthPass
: public PassWrapper<TPUSpaceToDepthPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
// Handle padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return failure();
auto pad_input_shape = ranked_type.getShape();
Location loc = op.getLoc();
OpBuilder builder(op);
builder.setInsertionPoint(op);
auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));
// Calculate paddings.
int32_t pad_total = kernel_size - 1;
int32_t pad_beg = (pad_total / 2 + 1) / block_size;
int32_t pad_end = (pad_total / 2) / block_size;
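// For example, with kernel_size = 7 and block_size = 2 (as in the test above):
// pad_total = 6, pad_beg = (3 + 1) / 2 = 2 and pad_end = 3 / 2 = 1, so a
// 2x112x112x12 input is padded to 2x115x115x12.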
SmallVector<int32_t, 8> values = {0, 0, pad_beg, pad_end,
pad_beg, pad_end, 0, 0};
auto paddings = DenseIntElementsAttr::get(padding_type, values);
// Update pad_op paddings.
op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));
// Set input type.
auto input = op.getOperand(0);
SmallVector<int64_t, 4> transform_shape = {
pad_input_shape[0], pad_input_shape[1] / block_size,
pad_input_shape[2] / block_size,
pad_input_shape[3] * block_size * block_size};
auto transform_result_type =
RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
input.setType(transform_result_type);
op.setOperand(0, input);
return success();
}
// Handle stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
MLIRContext* context = conv2d.getContext();
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(64, context), v);
});
// TODO(b/157276506): change type of strides to DenseElementsAttr
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
conv2d.setAttr("strides", strides);
}
// Transform input shape for the first convolution.
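// For example, a 2x224x224x3 input with block_size 2 becomes 2x112x112x12.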
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
auto input = conv2d.input();
auto input_shape = input.getType().cast<RankedTensorType>().getShape();
SmallVector<int64_t, 4> transform_shape = {
input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
input_shape[3] * block_size * block_size};
auto transform_result_type =
RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
input.setType(transform_result_type);
}
// Add padding for convolution filter for space to depth transform.
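// For example, a 7x7x3x64 filter with pad_h = pad_w = 1 becomes an 8x8x3x64
// filter.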
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
OpBuilder* builder, int32_t pad_h,
int32_t pad_w) {
SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
auto padding_type =
RankedTensorType::get({4, 2}, builder->getIntegerType(32));
auto paddings = DenseIntElementsAttr::get(padding_type, values);
auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
filter_shape[1] + pad_w, filter_shape[2],
filter_shape[3]};
SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());
auto expand_result_type =
RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
paddings_value);
}
// Create reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
Value input, OpBuilder* builder) {
auto reshape_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
auto reshape_type = RankedTensorType::get(
{static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
auto reshape_value =
builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
input, reshape_value);
}
// Create transpose op for space to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
auto permute_value =
builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}
void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
// For example, if filter shape is [7, 7, 3, 64] with block_size 2,
// will apply below transforms to the filter:
// 1. Pad the filter to [8, 8, 3, 64]
// 2. Reshape to [4, 2, 4, 2, 3, 64]
// 3. Transpose to [4, 4, 2, 2, 3, 64]
// 4. Reshape to [4, 4, 12, 64]
auto filter = conv2d.filter();
OpBuilder builder(conv2d);
builder.setInsertionPoint(conv2d);
// Book keeping filter information.
auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
int64_t height = filter_shape[0];
int64_t width = filter_shape[1];
int64_t channel = filter_shape[2];
int64_t out_channel = filter_shape[3];
// Value/Op before reshape op.
Value before_reshape_value = filter;
if (height % block_size != 0 || width % block_size != 0) {
// Calculate paddings for height and width.
int32_t pad_h = block_size - height % block_size;
int32_t pad_w = block_size - width % block_size;
auto pad_op =
GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
// Update op, height and width before reshape.
before_reshape_value = pad_op;
height = height + pad_h;
width = width + pad_w;
}
// Reshape.
SmallVector<int64_t, 6> new_shape = {
height / block_size, block_size, width / block_size,
block_size, channel, out_channel};
auto reshape_op =
GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);
// Transpose.
auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);
// Reshape Back.
SmallVector<int64_t, 4> final_shape = {
height / block_size, width / block_size,
channel * block_size * block_size, out_channel};
auto final_reshape_op =
GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
// Update filter of Conv2D.
conv2d.setOperand(1, final_reshape_op);
}
// Create slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
old_filter_shape.end());
auto slice_result_type =
RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
auto slice_size_op = builder->create<TF::ConstOp>(
input.getLoc(),
DenseIntElementsAttr::get(
RankedTensorType::get({4}, builder->getIntegerType(32)),
old_filter_shape));
SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
auto start_position_type =
RankedTensorType::get({4}, builder->getIntegerType(64));
auto start_position = builder->create<TF::ConstOp>(
input.getLoc(),
DenseIntElementsAttr::get(start_position_type, slice_start_position));
return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
start_position, slice_size_op);
}
// Transform Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
ArrayRef<int32_t> old_filter_shape,
ArrayRef<int32_t> new_filter_shape,
int64_t block_size) {
OpBuilder builder(backprop);
builder.setInsertionPoint(backprop);
auto input = backprop.input();
// Get new filter size from new_filter_shape.
auto new_filter_sizes = builder.create<TF::ConstOp>(
backprop.getLoc(),
DenseIntElementsAttr::get(
RankedTensorType::get({4}, builder.getIntegerType(32)),
new_filter_shape));
// Set stride to [1, 1, 1, 1].
MLIRContext* context = backprop.getContext();
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
});
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
// new result type.
SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
new_filter_shape.end());
auto new_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
// Build new BackPropFilterOp.
auto loc = backprop.getLoc();
auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
backprop.explicit_paddings(), backprop.data_format(),
backprop.dilations());
// For example, if new filter shape is [4, 4, 12, 64], old filter shape
// is [7, 7, 3, 64] with block_size 2.
// Below transforms will be applied to the filter:
// 1. Reshape to [4, 4, 2, 2, 3, 64];
// 2. Transpose to [4, 2, 4, 2, 3, 64];
// 3. Reshape to [8, 8, 3, 64];
// 4. Slice to [7, 7, 3, 64].
SmallVector<int64_t, 6> first_reshape_shape = {
new_filter_shape[0],
new_filter_shape[1],
block_size,
block_size,
new_filter_shape[2] / (block_size * block_size),
new_filter_shape[3]};
auto first_reshape_op =
GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);
// Transpose.
auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);
// Last Reshape op.
SmallVector<int64_t, 4> last_reshape_shape = {
new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
auto final_reshape_op =
GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);
// create slice op.
auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
final_reshape_op, &builder);
// Update backprop's user with the slice op.
backprop.replaceAllUsesWith(slice_op.getResult());
}
// Update func argument types to have the updated input shapes.
void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
arg_types.reserve(func.getNumArguments());
for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType());
auto terminator = func.front().getTerminator();
SmallVector<Type, 4> result_types(terminator->operand_type_begin(),
terminator->operand_type_end());
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}
void HandleFuncOp(Operation* op) {
auto func = llvm::cast<FuncOp>(op);
UpdateFuncType(func);
}
// Checks if the input producer op is supported in this transform. Right now, we
// only check if it is a host tf.IteratorGetNext.
bool IsSupportedHostInputOp(Operation* op) {
TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
if (!iter) return false;
auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
if (!device) return false;
tensorflow::DeviceNameUtils::ParsedName parsed_device;
if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
&parsed_device)) {
return false;
}
return parsed_device.type == "CPU";
}
// Builds a SpaceToDepthOp for the given cluster_func op and input.
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
Value input, int32_t block_size,
ArrayRef<int64_t> input_shape) {
auto input_op = input.getDefiningOp();
OpBuilder builder(input_op);
builder.setInsertionPointAfter(input_op);
SmallVector<int64_t, 4> transform_shape = {
input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
input_shape[3] * block_size * block_size};
auto transform_result_type =
RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
return builder.create<TF::SpaceToDepthOp>(cluster_func.getLoc(),
transform_result_type, input,
APInt(64, block_size));
}
// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
tf_device::ClusterFuncOp cluster_func,
int32_t block_size,
ArrayRef<int64_t> input_shape) {
auto space_to_depth =
BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
cluster_func.setOperand(index, space_to_depth);
return space_to_depth;
}
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
tf_device::ClusterFuncOp cluster_func,
int64_t replicate_arg_index,
tf_device::ReplicateOp replicate,
int32_t block_size) {
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
// Gets inputs at replicate_arg_index for each replica.
auto inputs = replicate.getOperands()
.drop_front(replicate_arg_index * num_replicas)
.take_front(num_replicas);
for (auto input : inputs) {
auto input_op = input.getDefiningOp();
if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
}
for (auto entry : llvm::enumerate(inputs)) {
auto ranked_type = entry.value().getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return false;
auto input_shape = ranked_type.getShape();
auto space_to_depth =
BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
space_to_depth);
}
return true;
}
// Performs the transformation on the host inputs of the cluster_func op that
// feed the given argument position, handling both replicated and
// non-replicated inputs.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
unsigned arg_num) {
auto maybe_replicate =
llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func.getParentOp());
llvm::SmallVector<int64_t, 8> transform_input_indices;
for (auto input : llvm::enumerate(cluster_func.operands())) {
if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
if (block_arg.getArgNumber() != arg_num) continue;
// For a block argument, consider transforms only when it is a replicated
// input (defining ops will be outside the replicate node).
if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
HandleHostReplicatedInputs(input.index(), cluster_func,
block_arg.getArgNumber(), maybe_replicate,
block_size);
}
} else {
// For an op output, consider transforms only when 1) there is no
// replication or 2) it is outside the replicate node that encloses the
// execute node. (If the op is inside the replicate, it is probably not on
// the host.)
if (input.index() != arg_num) continue;
auto input_op = input.value().getDefiningOp();
if (maybe_replicate &&
maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
continue;
}
if (!IsSupportedHostInputOp(input_op)) continue;
auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
if (!ranked_type) continue;
auto input_shape = ranked_type.getShape();
HandleHostInput(input.value(), input.index(), cluster_func, block_size,
input_shape);
}
}
}
// Check if input shape of convolution is good for space to depth transform.
bool Conv2DInputShapeCanTransform(Value input) {
auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return false;
auto input_shape = ranked_type.getShape();
int32_t batch_size = input_shape[0];
int32_t channel = input_shape[3];
if (batch_size > 8 || channel > 8) {
return false;
}
return true;
}
// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph, whose batch size is no larger than
// 8 and whose input feature size is no larger than 8, can be transformed.
Optional<std::pair<unsigned, int>> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
return None;
}
auto conv2d_input = conv2d.input();
if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(conv2d_input)) return None;
int num_users =
std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
return std::make_pair(block_arg.getArgNumber(), num_users);
}
if (auto pad_op = llvm::dyn_cast<TF::PadOp>(conv2d_input.getDefiningOp())) {
auto pad_input = pad_op.input();
if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(pad_input)) return None;
int num_users = std::distance(block_arg.getUsers().begin(),
block_arg.getUsers().end());
return std::make_pair(block_arg.getArgNumber(), num_users);
}
}
return None;
}
// Apply space to depth transform for the first convolution on TPU device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
// Check if input and filter type are RankedTensorType.
auto input_tensor_type =
conv2d.input().getType().dyn_cast<RankedTensorType>();
auto filter_tensor_type =
conv2d.filter().getType().dyn_cast<RankedTensorType>();
if (!input_tensor_type || !filter_tensor_type) return;
// Book keeping filter shape for padding and backprop filter rewrite.
auto filter_shape = filter_tensor_type.getShape();
SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
filter_shape.end());
// Handles input.
auto conv2d_input = conv2d.input();
if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
// Change on device function type/shape.
HandleFuncOp(block_arg.getOwner()->getParentOp());
}
if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
// Rewrite pad_op before the convolution.
if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
auto pad_input = pad_op.input();
if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
// Change on device function type/shape.
HandleFuncOp(block_arg.getOwner()->getParentOp());
}
}
// Handle Conv2D input, stride and filter.
HandleConv2DInput(conv2d, block_size);
HandleConv2DStride(conv2d);
HandleConv2DFilter(conv2d, block_size);
// Book keeping new filter shape for backprop filter rewrite.
// Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType.
filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
filter_shape.end());
// Rewrite Conv2DBackPropFilter after the first convolution.
for (Operation* user : conv2d.getOperation()->getUsers()) {
if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
block_size);
}
}
}
// Gets the block size, which equals the convolution's stride on the spatial
// dimensions.
// The space to depth transform won't be triggered if the block size is <= 1.
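// For example, strides = [1, 2, 2, 1] gives a block size of 2.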
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
SmallVector<int32_t, 4> strides(4, 1);
for (int i = 0; i < 4; ++i) {
strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
}
// Space to depth only supports striding at spatial dimension.
if (strides[0] != 1 || strides[3] != 1) return 1;
// Space to depth only supports height_stride == width_stride case.
if (strides[1] != strides[2]) return 1;
return strides[1];
}
void TPUSpaceToDepthPass::runOnOperation() {
Optional<tf_device::ClusterFuncOp> cluster_func;
// Space to depth only supports training loops.
auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
cluster_func = cluster;
return WalkResult::interrupt();
});
// Return if there is no tf_device::ClusterFuncOp in the training loop.
if (!func_result.wasInterrupted() || !cluster_func.hasValue()) {
return;
}
// Get the function on device.
auto device_func =
getOperation().lookupSymbol<mlir::FuncOp>(cluster_func->getFunc());
if (!device_func) return;
TF::Conv2DOp first_conv;
Optional<ArrayRef<int64_t>> input_shape;
// A map from block argument id to the convolutions that consume it.
llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
argnum_and_convolutions;
// A map from block argument id to its number of users.
llvm::SmallDenseMap<unsigned, int> argnum_num_users;
// Find out the qualified convolutions and their block argument ids.
auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
Optional<std::pair<unsigned, int>> arg_num_and_num_users =
GetConv2DInputArgNum(conv2d);
if (arg_num_and_num_users.hasValue()) {
// Get block size for the first convolution.
int64_t block_size = GetConv2DBlockSize(conv2d);
auto arg_num = arg_num_and_num_users.getValue().first;
auto num_users = arg_num_and_num_users.getValue().second;
argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
argnum_num_users[arg_num] = num_users;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (!conv2d_result.wasInterrupted()) {
return;
}
// Iterate through the block arguments and their convolution users. The space
// to depth transform will be applied only if all of the conditions below are
// satisfied:
// 1. All the users of the block argument lead to convolutions;
// 2. the block_size for the space to depth transform is the same for all of
// these convolutions;
// 3. the block_size for the space to depth transform of these convolutions
// is larger than 1.
for (auto argnum_and_convolution : argnum_and_convolutions) {
auto arg_num = argnum_and_convolution.getFirst();
auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
// Continue if the number of users of the block argument doesn't equal the
// number of transformable convolutions, if there is no qualified convolution
// for the transform, or if the block size is smaller than 2.
if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
conv2d_and_block_sizes.empty()) {
argnum_and_convolutions.erase(arg_num);
continue;
}
int64_t block_size = conv2d_and_block_sizes[0].second;
if (block_size < 2) {
argnum_and_convolutions.erase(arg_num);
continue;
}
// Continue if not all the block sizes for space to depth transform are the
// same.
for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
if (conv2d_and_block_size.second != block_size) {
argnum_and_convolutions.erase(arg_num);
break;
}
}
}
// Return if there is no qualified space to depth transform.
if (argnum_and_convolutions.empty()) {
return;
}
// Apply space to depth transform.
for (auto argnum_and_convolution : argnum_and_convolutions) {
auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
int64_t block_size = conv2d_and_block_sizes[0].second;
// Apply space to depth transform to the input on the host.
HandleCluster(cluster_func.getValue(), block_size,
argnum_and_convolution.getFirst());
// Transform the convolution.
for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
HandleFirstConvolution(conv2d_and_block_size.first,
conv2d_and_block_size.second);
}
}
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
return std::make_unique<TPUSpaceToDepthPass>();
}
static PassRegistration<TPUSpaceToDepthPass> pass(
"tf-tpu-space-to-depth-pass",
"Adds ops that allow the TPU program to enable automatic space to depth "
"transform for the convolution determined at JIT compile time.");
} // namespace TFTPU
} // namespace mlir