diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 5110ea7fbf5..05b2f891676 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -452,6 +452,7 @@ cc_library( "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", + "transforms/tpu_space_to_depth_pass.cc", "transforms/tpu_variable_runtime_reformatting.cc", "translate/breakup-islands.cc", "translate/control_to_executor_dialect.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir new file mode 100644 index 00000000000..aa333caa2ae --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -0,0 +1,87 @@ +// RUN: tf-opt %s -split-input-file -tf-tpu-space-to-depth-pass | FileCheck %s --dump-input=fail + +// Tests for space to depth host and device transform. + +module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 390 : i32}} { + func @main(%arg0: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg5: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) attributes {tf.entry_function = {control_outputs = "while", inputs = "iterator,iterator_1,iterator_2,iterator_3,while_input_6,while_input_7,while_input_8,while_input_9", outputs = ""}} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) + return + } + // CHECK-LABEL: func @while_body_2710 + func @while_body_2710(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) 
-> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) attributes {tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[INPUT:.*]] = "tf.IteratorGetNext" + %1 = "tf.IteratorGetNext"(%arg5) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor<2x224x224x3xf32> + // CHECK-DAG: %[[SPACETODEPTH0:.*]] = "tf.SpaceToDepth"([[INPUT:.*]]) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> + %2 = "tf.AddV2"(%arg2, %arg3) {device = ""} : (tensor, tensor) -> tensor + %3 = "tf.ReadVariableOp"(%arg6) : (tensor>>) -> tensor<7x7x3x64xf32> + %4 = "tf.ReadVariableOp"(%arg8) : (tensor>>) -> tensor + %5 = "tf.ReadVariableOp"(%arg7) : (tensor>>) -> tensor + %6 = "tf.ReadVariableOp"(%arg9) : (tensor>>) -> tensor + %7:2 = "tf_device.cluster_func"(%1, %3, %5, %6) {_tpu_replicate = "while/cluster_while_body_271", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<7x7x3x64xf32>, tensor, tensor) -> (tensor<7x7x3x64xf32>, tensor) + "tf.AssignVariableOp"(%arg6, %7#0) : (tensor>>, tensor<7x7x3x64xf32>) -> () + "tf.AssignVariableOp"(%arg9, %7#1) : (tensor>>, tensor) -> () + %8 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor + %9 = "tf.Identity"(%2) {device = ""} : (tensor) -> tensor + %10 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor, tensor) -> tensor + %11 = "tf.Identity"(%10) {device = ""} : (tensor) -> tensor + return %11, %8, %9, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9 : tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>> + } + func @while_cond_2700(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg7: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg8: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}, %arg9: tensor>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"}) -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.GreaterEqual"(%arg3, %0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Less"(%arg3, %0) {device = ""} : (tensor, tensor) -> tensor + %3 = "tf.Greater"(%arg2, %arg4) {device = ""} : (tensor, tensor) -> tensor + %4 = "tf.LogicalAnd"(%2, %3) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.Less"(%arg2, %arg4) {device = ""} : (tensor, tensor) -> tensor + %6 = "tf.LogicalAnd"(%1, %5) {device = ""} : (tensor, tensor) -> tensor + %7 = "tf.LogicalOr"(%6, %4) {device = ""} : (tensor, tensor) -> tensor + %8 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor + %9 = "tf.LogicalAnd"(%8, %7) {device = ""} : (tensor, tensor) -> tensor + %10 = "tf.Identity"(%9) {device = ""} : (tensor) -> tensor + return %10 : tensor + } + // CHECK-LABEL: func @_func + // 
CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32> + %3 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> + %4 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %5 = "tf.Pad"(%arg0, %3) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32> + // CHECK: "tf.Conv2D" + // CHECK-SAME: strides = [1, 1, 1, 1] + // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32> + %6 = "tf.Conv2D"(%5, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32> + // CHECK: %[[BACKPROP:.*]] = "tf.Conv2DBackpropFilter" + // CHECK-SAME: strides = [1, 1, 1, 1] + // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<4x4x12x64xf32> + %7 = "tf.Conv2DBackpropFilter"(%5, %2, %6) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<4xi32>, tensor<2x112x112x64xf32>) -> tensor<7x7x3x64xf32> + // CHECK: %[[CONST0:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [4, 4, 2, 2, 3, 64] + // CHECK: %[[RESHAPE0:.*]] = "tf.Reshape"(%[[BACKPROP:.*]], %[[CONST0:.*]]) : (tensor<4x4x12x64xf32>, tensor<6xi64>) -> tensor<4x4x2x2x3x64xf32> + // CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [0, 2, 1, 3, 4, 5] + // CHECK: %[[TRANSPOSE0:.*]] = "tf.Transpose"(%[[RESHAPE0:.*]], %[[CONST1:.*]]) : (tensor<4x4x2x2x3x64xf32>, tensor<6xi32>) -> tensor<4x2x4x2x3x64xf32> + // CHECK: %[[CONST2:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [8, 8, 3, 64] + // CHECK: %[[RESHAPE1:.*]] = "tf.Reshape"(%[[TRANSPOSE1:.*]], %[[CONST2:.*]]) : (tensor<4x2x4x2x3x64xf32>, tensor<4xi64>) -> tensor<8x8x3x64xf32> + // CHECK: %[[CONST3:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: [7, 7, 3, 64] + // CHECK: %[[CONST4:.*]] = "tf.Const"() {value = dense< + // CHECK-SAME: 0 + // CHECK: %[[SLICE0:.*]] = "tf.Slice"(%[[RESHAPE1:.*]], %[[CONST4:.*]], %[[CONST3:.*]]) : (tensor<8x8x3x64xf32>, tensor<4xi64>, tensor<4xi32>) -> tensor<7x7x3x64xf32> + %8 = "tf.CrossReplicaSum"(%7, %1) : (tensor<7x7x3x64xf32>, tensor<1x1xi32>) -> tensor<7x7x3x64xf32> + %9 
= "tf.Mul"(%arg2, %8) : (tensor, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32> + %10 = "tf.Sub"(%arg1, %9) : (tensor<7x7x3x64xf32>, tensor<7x7x3x64xf32>) -> tensor<7x7x3x64xf32> + %11 = "tf.AddV2"(%arg3, %0) : (tensor, tensor) -> tensor + return %10, %11 : tensor<7x7x3x64xf32>, tensor + } +} + +// ---- + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc new file mode 100644 index 00000000000..7befa68f3d8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -0,0 +1,703 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +constexpr char kDeviceAttr[] = "device"; +typedef std::pair Conv2DWithBlockSize; + +// A pass that applies automatic space to depth transform for the first or +// frontier convolutions consume host inputs on TPU. +// This is done by adding space to depth transform op after host input and +// applying space to depth transform for the first convolution and its backprop +// filter on TPU. +// +// Example: original program: +// +// module { +// func @while_body { +// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: +// -> tensor<2x224x224x3xf32> +// %device_launch = "tf_device.cluster_func"(%input,...) 
{func = @_func,...) +// return ... +// } +// func @_func(%input: tensor<2x224x224x3xf32>, +// %filter: tensor<7x7x3x64xf32>) { +// %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}: +// (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> +// tensor<2x112x112x64xf32> +// } +// } +// +// With this pass, the program will be transformed into: +// module { +// func @while_body { +// %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"} +// -> tensor<2x224x224x3xf32> +// %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}: +// (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32> +// %device_launch = "tf_device.cluster_func"(%space_to_depth,...) +// {func = @_func,...) +// return ... +// } +// func @_func(%input: tensor<2x112x112x12xf32>, +// %filter: tensor<7x7x3x64xf32>) { +// %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter): +// tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32> +// %conv = "tf.Conv2D"(%input, %filter_transfrom) {strides = [1, 1, 1, 1]}: +// (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) -> +// tensor<2x112x112x64xf32> +// } +// } +// +// This way, the first convolution with 3 feature dimension will be transformed +// to 12 feature dimension, which has better performance on TPU. +// +// TODO(wangtao): add a pass to check if it is profitable to space to depth +// transform and invoke the transform if it is needed. +struct TPUSpaceToDepthPass + : public PassWrapper> { + void runOnOperation() override; +}; + +// Handle padding before convolution for space to depth transform. +LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { + auto ranked_type = op.input().getType().dyn_cast(); + if (!ranked_type) return failure(); + auto pad_input_shape = ranked_type.getShape(); + Location loc = op.getLoc(); + OpBuilder builder(op); + builder.setInsertionPoint(op); + auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32)); + + // Calculate paddings. + int32_t pad_total = kernel_size - 1; + int32_t pad_beg = (pad_total / 2 + 1) / block_size; + int32_t pad_end = (pad_total / 2) / block_size; + SmallVector values = {0, 0, pad_beg, pad_end, + pad_beg, pad_end, 0, 0}; + auto paddings = DenseIntElementsAttr::get(padding_type, values); + // Update pad_op paddings. + op.setOperand(1, builder.create(loc, paddings)); + + // Set input type. + auto input = op.getOperand(0); + SmallVector transform_shape = { + pad_input_shape[0], pad_input_shape[1] / block_size, + pad_input_shape[2] / block_size, + pad_input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + input.setType(transform_result_type); + op.setOperand(0, input); + return success(); +} + +// Handle stride for the first convolution for the transform. +void HandleConv2DStride(TF::Conv2DOp conv2d) { + MLIRContext* context = conv2d.getContext(); + SmallVector values = {1, 1, 1, 1}; + auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { + return IntegerAttr::get(IntegerType::get(64, context), v); + }); + // TODO(b/157276506): change type of strides to DenseElementsAttr + auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context); + conv2d.setAttr("strides", strides); +} + +// Transform input shape for the first convolution. 
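+// For example, with block_size = 2 an input of shape [2, 224, 224, 3]
+// becomes [2, 112, 112, 12]: each spatial dimension is divided by the block
+// size and the channel dimension is multiplied by block_size * block_size.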
+void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { + auto input = conv2d.input(); + auto input_shape = input.getType().cast().getShape(); + SmallVector transform_shape = { + input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, + input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + input.setType(transform_result_type); +} + +// Add padding for convolution filter for space to depth transform. +TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter, + OpBuilder* builder, int32_t pad_h, + int32_t pad_w) { + SmallVector values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0}; + auto padding_type = + RankedTensorType::get({4, 2}, builder->getIntegerType(32)); + auto paddings = DenseIntElementsAttr::get(padding_type, values); + auto paddings_value = builder->create(filter.getLoc(), paddings); + std::vector pad_shape = {filter_shape[0] + pad_h, + filter_shape[1] + pad_w, filter_shape[2], + filter_shape[3]}; + SmallVector expand_shape(pad_shape.begin(), pad_shape.end()); + + auto expand_result_type = + RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter)); + return builder->create(filter.getLoc(), expand_result_type, filter, + paddings_value); +} + +// Create reshape op for space to depth transform. +TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape, + Value input, OpBuilder* builder) { + auto reshape_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(input)); + auto reshape_type = RankedTensorType::get( + {static_cast(new_shape.size())}, builder->getIntegerType(64)); + auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape); + auto reshape_value = + builder->create(input.getLoc(), reshape_sizes); + return builder->create(input.getLoc(), reshape_result_type, + input, reshape_value); +} + +// Create transpose op for shape to depth transform. +TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) { + SmallVector permutation = {0, 2, 1, 3, 4, 5}; + auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32)); + auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation); + auto permute_value = + builder->create(input.getLoc(), permute_attr); + return builder->create(input.getLoc(), input, permute_value); +} + +void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { + // For example, if filter shape is [7, 7, 3, 64] with block_size 2, + // will apply below transforms to the filter: + // 1. Pad the filter to [8, 8, 3, 64] + // 2. Reshape to [4, 2, 4, 2, 3, 64] + // 3. Transpose to [4, 4, 2, 2, 3, 64] + // 4. Reshape to [4, 4, 12, 64] + auto filter = conv2d.filter(); + OpBuilder builder(conv2d); + builder.setInsertionPoint(conv2d); + // Book keeping filter information. + auto filter_shape = filter.getType().cast().getShape(); + int64_t height = filter_shape[0]; + int64_t width = filter_shape[1]; + int64_t channel = filter_shape[2]; + int64_t out_channel = filter_shape[3]; + // Value/Op before reshape op. + Value before_reshape_value = filter; + if (height % block_size != 0 || width % block_size != 0) { + // Calculate paddings for height and width. + int32_t pad_h = block_size - height % block_size; + int32_t pad_w = block_size - width % block_size; + auto pad_op = + GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w); + // Update op, height and width before reshape. 
+ before_reshape_value = pad_op; + height = height + pad_h; + width = width + pad_w; + } + + // Reshape. + SmallVector new_shape = { + height / block_size, block_size, width / block_size, + block_size, channel, out_channel}; + auto reshape_op = + GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder); + + // Transpose. + auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op); + + // Reshape Back. + SmallVector final_shape = { + height / block_size, width / block_size, + channel * block_size * block_size, out_channel}; + auto final_reshape_op = + GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder); + // Update filter of Conv2D. + conv2d.setOperand(1, final_reshape_op); +} + +// Create slice op for filter in back prop pass. +TF::SliceOp GetSliceOpForConv2DBackPropFilter( + ArrayRef old_filter_shape, Value input, OpBuilder* builder) { + SmallVector slice_size(old_filter_shape.begin(), + old_filter_shape.end()); + auto slice_result_type = + RankedTensorType::get(slice_size, getElementTypeOrSelf(input)); + auto slice_size_op = builder->create( + input.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({4}, builder->getIntegerType(32)), + old_filter_shape)); + SmallVector slice_start_position = {0, 0, 0, 0}; + auto start_position_type = + RankedTensorType::get({4}, builder->getIntegerType(64)); + auto start_position = builder->create( + input.getLoc(), + DenseIntElementsAttr::get(start_position_type, slice_start_position)); + return builder->create(input.getLoc(), slice_result_type, input, + start_position, slice_size_op); +} + +// Transform Conv2DBackPropFilter for space to depth. +void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, + ArrayRef old_filter_shape, + ArrayRef new_filter_shape, + int64_t block_size) { + OpBuilder builder(backprop); + builder.setInsertionPoint(backprop); + + auto input = backprop.input(); + // Get new filter size from new_filter_shape. + auto new_filter_sizes = builder.create( + backprop.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({4}, builder.getIntegerType(32)), + new_filter_shape)); + + // Set stride to [1, 1, 1, 1]. + MLIRContext* context = backprop.getContext(); + SmallVector values = {1, 1, 1, 1}; + auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { + return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v)); + }); + auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context); + + // new result type. + SmallVector new_shape(new_filter_shape.begin(), + new_filter_shape.end()); + auto new_result_type = + RankedTensorType::get(new_shape, getElementTypeOrSelf(input)); + + // Build new BackPropFilterOp. + auto loc = backprop.getLoc(); + auto new_backprop = builder.create( + loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(), + strides, backprop.use_cudnn_on_gpu(), backprop.padding(), + backprop.explicit_paddings(), backprop.data_format(), + backprop.dilations()); + + // For example, if new filter shape is [4, 4, 12, 64], old filter shape + // is [7, 7, 3, 64] with block_size 2. + // Below transforms will be applied to the filter: + // 1. Reshape to [4, 4, 2, 2, 3, 64]; + // 2. Transpose to [4, 2, 4, 2, 3, 64]; + // 3. Reshape to [8, 8, 3, 64]; + // 4. Slice to [7, 7, 3, 64]. 
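+  // These steps invert the pad/reshape/transpose/reshape sequence applied to
+  // the forward filter in HandleConv2DFilter, so the computed gradient is
+  // produced in the original (unpadded) filter layout.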
+ SmallVector first_reshape_shape = { + new_filter_shape[0], + new_filter_shape[1], + block_size, + block_size, + new_filter_shape[2] / (block_size * block_size), + new_filter_shape[3]}; + auto first_reshape_op = + GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder); + + // Transpose. + auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op); + + // Last Reshape op. + SmallVector last_reshape_shape = { + new_filter_shape[0] * block_size, new_filter_shape[1] * block_size, + new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]}; + auto final_reshape_op = + GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder); + + // create slice op. + auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape, + final_reshape_op, &builder); + + // Update backprop's user with the slice op. + backprop.replaceAllUsesWith(slice_op.getResult()); +} + +// Update func arugument type to have the updated input shape. +void UpdateFuncType(FuncOp func) { + llvm::SmallVector arg_types; + arg_types.reserve(func.getNumArguments()); + for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType()); + auto terminator = func.front().getTerminator(); + SmallVector result_types(terminator->operand_type_begin(), + terminator->operand_type_end()); + func.setType(FunctionType::get(arg_types, result_types, func.getContext())); +} + +void HandleFuncOp(Operation* op) { + auto func = llvm::cast(op); + UpdateFuncType(func); +} + +// Checks if the input producer op is supported in this transform. Right now, we +// only check if it is a host tf.IteratorGetNext. +bool IsSupportedHostInputOp(Operation* op) { + TF::IteratorGetNextOp iter = llvm::dyn_cast(op); + if (!iter) return false; + auto device = op->getAttrOfType(kDeviceAttr); + if (!device) return false; + tensorflow::DeviceNameUtils::ParsedName parsed_device; + if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(), + &parsed_device)) { + return false; + } + return parsed_device.type == "CPU"; +} + +// Builds a SpaceToDepthOp with the given get_layout op and input. +TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func, + Value input, int32_t block_size, + ArrayRef input_shape) { + auto input_op = input.getDefiningOp(); + OpBuilder builder(input_op); + builder.setInsertionPointAfter(input_op); + SmallVector transform_shape = { + input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, + input_shape[3] * block_size * block_size}; + auto transform_result_type = + RankedTensorType::get(transform_shape, getElementTypeOrSelf(input)); + return builder.create(cluster_func.getLoc(), + transform_result_type, input, + APInt(64, block_size)); +} + +// Performs transformation for a non-replicated input. +TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index, + tf_device::ClusterFuncOp cluster_func, + int32_t block_size, + ArrayRef input_shape) { + auto space_to_depth = + BuildSpaceToDepth(cluster_func, input, block_size, input_shape); + cluster_func.setOperand(index, space_to_depth); + return space_to_depth; +} + +// Performs transformation for replicated inputs. Returns true if this is a +// supported case (thus transform happened). +bool HandleHostReplicatedInputs(int64_t index, + tf_device::ClusterFuncOp cluster_func, + int64_t replicate_arg_index, + tf_device::ReplicateOp replicate, + int32_t block_size) { + // We need to know the devices to copy to. 
+ if (!replicate.devices()) return false; + int64_t num_replicas = replicate.n().getZExtValue(); + // Gets inputs at replicate_arg_index for each replica. + auto inputs = replicate.getOperands() + .drop_front(replicate_arg_index * num_replicas) + .take_front(num_replicas); + for (auto input : inputs) { + auto input_op = input.getDefiningOp(); + if (!input_op || !IsSupportedHostInputOp(input_op)) return false; + } + for (auto entry : llvm::enumerate(inputs)) { + auto ranked_type = entry.value().getType().dyn_cast(); + if (!ranked_type) return false; + auto input_shape = ranked_type.getShape(); + auto space_to_depth = + BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape); + replicate.setOperand(num_replicas * replicate_arg_index + entry.index(), + space_to_depth); + } + return true; +} + +// Performs transformation on a pair of execute and compile ops. The compile +// should not have other uses. +void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, + unsigned arg_num) { + auto maybe_replicate = + llvm::dyn_cast(cluster_func.getParentOp()); + + llvm::SmallVector transform_input_indices; + for (auto input : llvm::enumerate(cluster_func.operands())) { + if (auto block_arg = input.value().dyn_cast()) { + if (block_arg.getArgNumber() != arg_num) continue; + // For a block argument, consider transforms only when it is a replicated + // input (defining ops will be outside the replicate node). + if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) { + HandleHostReplicatedInputs(input.index(), cluster_func, + block_arg.getArgNumber(), maybe_replicate, + block_size); + } + } else { + // For an op output, consider transforms only when 1) there is no + // replicateion or 2) it is outside the replicate node that encloses the + // execute node. (Because if the op is inside replicate, it is probably + // not on the host.) + if (input.index() != arg_num) continue; + auto input_op = input.value().getDefiningOp(); + if (maybe_replicate && + maybe_replicate.body().isAncestor(input_op->getParentRegion())) { + continue; + } + if (!IsSupportedHostInputOp(input_op)) continue; + auto ranked_type = input.value().getType().dyn_cast(); + if (!ranked_type) continue; + auto input_shape = ranked_type.getShape(); + HandleHostInput(input.value(), input.index(), cluster_func, block_size, + input_shape); + } + } +} + +// Check if input shape of convolution is good for space to depth transform. +bool Conv2DInputShapeCanTransform(Value input) { + auto ranked_type = input.getType().dyn_cast(); + if (!ranked_type) return false; + auto input_shape = ranked_type.getShape(); + int32_t batch_size = input_shape[0]; + int32_t channel = input_shape[3]; + if (batch_size > 8 || channel > 8) { + return false; + } + return true; +} + +// Checks if a convoluton can apply SpaceToDepth transform. +// Only the first convolution in the graph whose batch size smaller than 8 +// and its input feature size smaller than 8 can be transformed. 
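+// Returns the index of the block argument that feeds the convolution (either
+// directly or through a tf.Pad), paired with the number of users of that
+// argument; returns None if the convolution cannot be transformed.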
+Optional> GetConv2DInputArgNum(TF::Conv2DOp conv2d) { + if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) { + return None; + } + auto conv2d_input = conv2d.input(); + if (auto block_arg = conv2d_input.dyn_cast()) { + if (!Conv2DInputShapeCanTransform(conv2d_input)) return None; + int num_users = + std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end()); + return std::make_pair(block_arg.getArgNumber(), num_users); + } + + if (auto pad_op = llvm::dyn_cast(conv2d_input.getDefiningOp())) { + auto pad_input = pad_op.input(); + if (auto block_arg = pad_input.dyn_cast()) { + if (!Conv2DInputShapeCanTransform(pad_input)) return None; + int num_users = std::distance(block_arg.getUsers().begin(), + block_arg.getUsers().end()); + return std::make_pair(block_arg.getArgNumber(), num_users); + } + } + + return None; +} + +// Apply space to depth transform for the first convolution on TPU device. +void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { + // Check if input and filter type are RankedTensorType. + auto input_tensor_type = + conv2d.input().getType().dyn_cast(); + auto filter_tensor_type = + conv2d.filter().getType().dyn_cast(); + if (!input_tensor_type || !filter_tensor_type) return; + // Book keeping filter shape for padding and backprop filter rewrite. + auto filter_shape = filter_tensor_type.getShape(); + SmallVector old_filter_shape(filter_shape.begin(), + filter_shape.end()); + // Handles input. + auto conv2d_input = conv2d.input(); + if (auto block_arg = conv2d_input.dyn_cast()) { + // Change on device function type/shape. + HandleFuncOp(block_arg.getOwner()->getParentOp()); + } + + if (auto pad_op = dyn_cast_or_null(conv2d_input.getDefiningOp())) { + // Rewrite pad_op before Convolutioin. + if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return; + auto pad_input = pad_op.input(); + if (auto block_arg = pad_input.dyn_cast()) { + // Change on device function type/shape. + HandleFuncOp(block_arg.getOwner()->getParentOp()); + } + } + + // Handle Conv2D input, stride and filter. + HandleConv2DInput(conv2d, block_size); + HandleConv2DStride(conv2d); + HandleConv2DFilter(conv2d, block_size); + + // Book keeping new filter shape for backprop filter rewrite. + // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType. + filter_shape = conv2d.filter().getType().cast().getShape(); + SmallVector new_filter_shape(filter_shape.begin(), + filter_shape.end()); + + // Rewrite Conv2DBackPropFilter after the first convolution. + for (Operation* user : conv2d.getOperation()->getUsers()) { + if (auto backprop = dyn_cast(user)) { + HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape, + block_size); + } + } +} + +// Get block size that is equal to stride from spatial dimension +// from convolution. +// Space to depth transform won't be triggered if block size <= 1. +int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) { + SmallVector strides(4, 1); + for (int i = 0; i < 3; ++i) { + strides[i] = conv2d.strides()[i].cast().getInt(); + } + + // Space to depth only supports striding at spatial dimension. + if (strides[0] != 1 || strides[3] != 1) return 1; + + // Space to depth only supports height_stride == width_stride case. + if (strides[1] != strides[2]) return 1; + + return strides[1]; +} + +void TPUSpaceToDepthPass::runOnOperation() { + Optional cluster_func; + // Space to depth only supports training loop. 
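+  // Find the first tf_device.cluster_func in the module; this is the TPU
+  // computation extracted from the training loop body. If none is found,
+  // the pass is a no-op.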
+ auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) { + cluster_func = cluster; + return WalkResult::interrupt(); + }); + + // Return if there is no tf_device::ClusterFuncOp in training loop. + if (!func_result.wasInterrupted() || !cluster_func.hasValue()) { + return; + } + + // Get the function on device. + auto device_func = + getOperation().lookupSymbol(cluster_func->getFunc()); + if (!device_func) return; + + TF::Conv2DOp first_conv; + Optional> input_shape; + // A map maps block argument id to the convolutions consumes them. + llvm::SmallDenseMap> + argnum_and_convolutions; + // A map maps block argument id to the number of users. + llvm::SmallDenseMap argnum_num_users; + + // Find out the qualified convolutions and its block argument ids. + auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) { + Optional> arg_num_and_num_users = + GetConv2DInputArgNum(conv2d); + if (arg_num_and_num_users.hasValue()) { + // Get block size for the first convolution. + int64_t block_size = GetConv2DBlockSize(conv2d); + auto arg_num = arg_num_and_num_users.getValue().first; + auto num_users = arg_num_and_num_users.getValue().second; + argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size); + argnum_num_users[arg_num] = num_users; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (!conv2d_result.wasInterrupted()) { + return; + } + + // Iterate through block argument and its convolution users. Space to depth + // transform will be applied only if all the below conditions are satisfied: + // 1. All the users of the block argument will lead to convolutions; + // 2. block_size of for the space to depth transform for these convolutions + // are the same; + // 3. block_size of for the space to depth transform for these convolutions + // are larger than 1. + for (auto argnum_and_convolution : argnum_and_convolutions) { + auto arg_num = argnum_and_convolution.getFirst(); + auto conv2d_and_block_sizes = argnum_and_convolution.getSecond(); + // Continue if number of users of the block argment doesn't equal to number + // of transformable convolutions and there is no qualified convolution + // for transform or block size is smaller than 2. + if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() || + conv2d_and_block_sizes.empty()) { + argnum_and_convolutions.erase(arg_num); + continue; + } + int64_t block_size = conv2d_and_block_sizes[0].second; + if (block_size < 2) { + argnum_and_convolutions.erase(arg_num); + continue; + } + // Continue if not all the block sizes for space to depth transform are the + // same. + for (auto conv2d_and_block_size : conv2d_and_block_sizes) { + if (conv2d_and_block_size.second != block_size) { + argnum_and_convolutions.erase(arg_num); + break; + } + } + } + + // If there is no qualified space to depth transform. + if (argnum_and_convolutions.empty()) { + return; + } + + // Apply space to depth transform. + for (auto argnum_and_convolution : argnum_and_convolutions) { + auto conv2d_and_block_sizes = argnum_and_convolution.getSecond(); + int64_t block_size = conv2d_and_block_sizes[0].second; + // Apply space to depth transform to the input on the host. + HandleCluster(cluster_func.getValue(), block_size, + argnum_and_convolution.getFirst()); + // Transform the convolution. 
+    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
+      HandleFirstConvolution(conv2d_and_block_size.first,
+                             conv2d_and_block_size.second);
+    }
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
+  return std::make_unique<TPUSpaceToDepthPass>();
+}
+
+static PassRegistration<TPUSpaceToDepthPass> pass(
+    "tf-tpu-space-to-depth-pass",
+    "Adds ops that allow the TPU program to enable automatic space to depth "
+    "for the convolution determined at JIT compile time.");
+
+}  // namespace TFTPU
+}  // namespace mlir