diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
index 280986a7ee1..ceecb3e72d9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir
@@ -83,5 +83,80 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
 }
 }
 
-// ----
+// -----
+
+// Tests for space to depth host and device transform with replicate inputs.
+
+module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
+  func @main(%arg0: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
+    // CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
+    // CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    %0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
+    // CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext"
+    // CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
+    %1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
+    tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {}, n = 2 : i32} {
+      %2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
+      %3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource<tensor<1001xf32>>>) -> tensor<1001xf32>
+      %4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource<tensor<64x1001xf32>>>) -> tensor<64x1001xf32>
+      %5 = "tf.ReadVariableOp"(%arg18) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+      %6 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+      %7 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+      %8 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
+      %9:4 = "tf_device.cluster_func"(%arg13, %arg14, %2, %4, %3, %5, %6, %7, %8) {_tpu_replicate = "cluster_eval_step", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<2x1xf32>, tensor<7x7x3x64xf32>, tensor<64x1001xf32>, tensor<1001xf32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
+      "tf.AssignVariableOp"(%arg18, %9#0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+      "tf.AssignVariableOp"(%arg19, %9#1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+      "tf.AssignVariableOp"(%arg20, %9#2) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+      "tf.AssignVariableOp"(%arg21, %9#3) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
+      tf_device.return
+    }
+    return
+  }
+  // CHECK-LABEL: func @_func
+  // CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
+  func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
+    %0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+    %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
+    %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor
+    %3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
+    %4 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
+    %5 = "tf.Const"() {value = dense<2.500000e-01> : tensor<f32>} : () -> tensor<f32>
+    %6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+    %7 = "tf.Const"() {value = dense<[-1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32>
+    %8 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+    %9 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
+    %10 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
+    %11 = "tf.Pad"(%arg0, %10) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32>
+    %12 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x1xf32>) -> tensor<2x1xi64>
+    %13 = "tf.Reshape"(%12, %9) : (tensor<2x1xi64>, tensor<1xi32>) -> tensor<2xi64>
+    %14 = "tf.Squeeze"(%arg1) {squeeze_dims = [-1]} : (tensor<2x1xf32>) -> tensor<2xf32>
+    // CHECK: "tf.Conv2D"
+    // CHECK-SAME: strides = [1, 1, 1, 1]
+    // CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
+    %15 = "tf.Conv2D"(%11, %arg2) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
+    %16 = "tf.Mean"(%15, %8) {keep_dims = false} : (tensor<2x112x112x64xf32>, tensor<2xi32>) -> tensor<2x64xf32>
+    %17 = "tf.MatMul"(%16, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x64xf32>, tensor<64x1001xf32>) -> tensor<2x1001xf32>
+    %18 = "tf.BiasAdd"(%17, %arg4) {data_format = "NHWC"} : (tensor<2x1001xf32>, tensor<1001xf32>) -> tensor<2x1001xf32>
+    %19 = "tf.Reshape"(%18, %7) : (tensor<2x1001xf32>, tensor<2xi32>) -> tensor<2x1001xf32>
+    %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%19, %13) : (tensor<2x1001xf32>, tensor<2xi64>) -> (tensor<2xf32>, tensor<2x1001xf32>)
+    %20 = "tf.Sum"(%loss, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+    %21 = "tf.Mul"(%20, %5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %22 = "tf.Sum"(%21, %4) {keep_dims = false} : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
+    %23 = "tf.CrossReplicaSum"(%22, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
+    %24 = "tf.Softmax"(%18) : (tensor<2x1001xf32>) -> tensor<2x1001xf32>
+    %25 = "tf.ArgMax"(%24, %2) : (tensor<2x1001xf32>, tensor) -> tensor<2xi64>
+    %26 = "tf.Cast"(%25) {Truncate = false} : (tensor<2xi64>) -> tensor<2xf32>
+    %27 = "tf.Equal"(%14, %26) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
+    %28 = "tf.Cast"(%27) {Truncate = false} : (tensor<2xi1>) -> tensor<2xf32>
+    %29 = "tf.Sum"(%28, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+    %30 = "tf.CrossReplicaSum"(%29, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
+    %31 = "tf.AddV2"(%arg5, %23) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %32 = "tf.CrossReplicaSum"(%1, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
+    %33 = "tf.AddV2"(%arg6, %32) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %34 = "tf.AddV2"(%arg7, %30) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %35 = "tf.CrossReplicaSum"(%0, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
+    %36 = "tf.AddV2"(%arg8, %35) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    return %31, %33, %34, %36 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
index 204a674e632..2f1db0899f7 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc
@@ -54,6 +54,11 @@ namespace {
 constexpr char kDeviceAttr[] = "device";
 typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;
 
+struct BlockArgumentInfo {
+  unsigned arg_num;
+  unsigned num_users;
+};
+
 // A pass that applies automatic space to depth transform for the first or
 // frontier convolutions consume host inputs on TPU.
 // This is done by adding space to depth transform op after host input and
@@ -108,7 +113,49 @@ struct TPUSpaceToDepthPass
   void runOnOperation() override;
 };
 
-// Handle padding before convolution for space to depth transform.
+// Updates func argument type to have the updated input shape.
+void UpdateFuncType(FuncOp func) {
+  auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes());
+  auto result_types =
+      llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes());
+  func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
+}
+
+void HandleFuncOp(Operation* op) {
+  auto func = llvm::cast<FuncOp>(op);
+  UpdateFuncType(func);
+}
+
+// Handles cast op between the first convolution and the block argument.
+LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
+  auto cast_input = cast_op.x();
+  // Update input type.
+  auto transform_result_type =
+      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
+  cast_input.setType(transform_result_type);
+  auto block_arg = cast_input.dyn_cast<BlockArgument>();
+  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
+  while (block_arg || cast_op_input) {
+    if (block_arg) {
+      // Change on device function type/shape.
+      HandleFuncOp(block_arg.getOwner()->getParentOp());
+      block_arg = nullptr;
+      cast_op_input = nullptr;
+    } else {
+      auto cast_input = cast_op_input.x();
+      // Update input type.
+      auto transform_result_type =
+          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
+      cast_input.setType(transform_result_type);
+      // Update block arg and cast_op_input.
+      block_arg = cast_input.dyn_cast<BlockArgument>();
+      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
+    }
+  }
+  return success();
+}
+
+// Handles padding before convolution for space to depth transform.
 LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
   auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
   if (!ranked_type) return failure();
@@ -134,6 +181,10 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
       pad_input_shape[0], pad_input_shape[1] / block_size,
       pad_input_shape[2] / block_size,
       pad_input_shape[3] * block_size * block_size};
+  // Input of the pad op could be a cast op.
+  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
+    if (failed(HandleCast(cast_op, transform_shape))) return failure();
+
   auto transform_result_type =
       RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
   input.setType(transform_result_type);
@@ -141,7 +192,7 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
   return success();
 }
 
-// Handle stride for the first convolution for the transform.
+// Handles stride for the first convolution for the transform.
 void HandleConv2DStride(TF::Conv2DOp conv2d) {
   MLIRContext* context = conv2d.getContext();
   SmallVector values = {1, 1, 1, 1};
@@ -153,7 +204,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
   conv2d.setAttr("strides", strides);
 }
 
-// Transform input shape for the first convolution.
+// Transforms input shape for the first convolution.
 void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
   auto input = conv2d.input();
   auto input_shape = input.getType().cast<RankedTensorType>().getShape();
@@ -165,7 +216,7 @@ void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
   input.setType(transform_result_type);
 }
 
-// Add padding for convolution filter for space to depth transform.
+// Adds padding for convolution filter for space to depth transform.
 TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter,
                                   OpBuilder* builder, int32_t pad_h,
                                   int32_t pad_w) {
@@ -185,7 +236,7 @@ TF::PadOp GetPadOpForConv2DFilter(ArrayRef filter_shape, Value filter,
                                   paddings_value);
 }
 
-// Create reshape op for space to depth transform.
+// Creates reshape op for space to depth transform.
 TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape, Value input,
                                           OpBuilder* builder) {
   auto reshape_result_type =
@@ -199,7 +250,7 @@ TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef new_shape,
                                           input, reshape_value);
 }
 
-// Create transpose op for shape to depth transform.
+// Creates transpose op for shape to depth transform.
 TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
   SmallVector permutation = {0, 2, 1, 3, 4, 5};
   auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
@@ -259,7 +310,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
   conv2d.setOperand(1, final_reshape_op);
 }
 
-// Create slice op for filter in back prop pass.
+// Creates slice op for filter in back prop pass.
 TF::SliceOp GetSliceOpForConv2DBackPropFilter(
     ArrayRef old_filter_shape, Value input, OpBuilder* builder) {
   SmallVector slice_size(old_filter_shape.begin(),
@@ -281,7 +332,7 @@ TF::SliceOp GetSliceOpForConv2DBackPropFilter(
                                      start_position, slice_size_op);
 }
 
-// Transform Conv2DBackPropFilter for space to depth.
+// Transforms Conv2DBackPropFilter for space to depth.
 void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                 ArrayRef old_filter_shape,
                                 ArrayRef new_filter_shape,
@@ -354,22 +405,6 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
   backprop.replaceAllUsesWith(slice_op.getResult());
 }
 
-// Update func arugument type to have the updated input shape.
-void UpdateFuncType(FuncOp func) {
-  llvm::SmallVector arg_types;
-  arg_types.reserve(func.getNumArguments());
-  for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType());
-  auto terminator = func.front().getTerminator();
-  SmallVector result_types(terminator->operand_type_begin(),
-                           terminator->operand_type_end());
-  func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
-}
-
-void HandleFuncOp(Operation* op) {
-  auto func = llvm::cast(op);
-  UpdateFuncType(func);
-}
-
 // Checks if the input producer op is supported in this transform. Right now, we
 // only check if it is a host tf.IteratorGetNext.
 bool IsSupportedHostInputOp(Operation* op) {
@@ -417,9 +452,10 @@ TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
 // supported case (thus transform happened).
 bool HandleHostReplicatedInputs(int64_t index,
                                 tf_device::ClusterFuncOp cluster_func,
-                                int64_t replicate_arg_index,
+                                BlockArgument block_arg,
                                 tf_device::ReplicateOp replicate,
                                 int32_t block_size) {
+  int64_t replicate_arg_index = block_arg.getArgNumber();
   // We need to know the devices to copy to.
   if (!replicate.devices()) return false;
   int64_t num_replicas = replicate.n().getZExtValue();
@@ -439,6 +475,7 @@ bool HandleHostReplicatedInputs(int64_t index,
     BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape);
     replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
                          space_to_depth);
+    block_arg.setType(space_to_depth.getType());
   }
   return true;
 }
@@ -457,9 +494,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
       // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
       if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
-        HandleHostReplicatedInputs(input.index(), cluster_func,
-                                   block_arg.getArgNumber(), maybe_replicate,
-                                   block_size);
+        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
+                                   maybe_replicate, block_size);
       }
     } else {
       // For an op output, consider transforms only when 1) there is no
@@ -482,7 +518,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
   }
 }
 
-// Check if input shape of convolution is good for space to depth transform.
+// Checks if input shape of convolution is good for space to depth transform.
 bool Conv2DInputShapeCanTransform(Value input) {
   auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
   if (!ranked_type) return false;
@@ -495,35 +531,59 @@ bool Conv2DInputShapeCanTransform(Value input) {
   return true;
 }
 
-// Checks if a convoluton can apply SpaceToDepth transform.
-// Only the first convolution in the graph whose batch size smaller than 8
-// and its input feature size smaller than 8 can be transformed.
-Optional> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
-  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
-    return None;
-  }
-  auto conv2d_input = conv2d.input();
-  if (auto block_arg = conv2d_input.dyn_cast()) {
-    if (!Conv2DInputShapeCanTransform(conv2d_input)) return None;
-    int num_users =
+// Gets block argument id and number of users for the input arg.
+Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
+  if (auto block_arg = arg.dyn_cast<BlockArgument>()) {
+    if (!Conv2DInputShapeCanTransform(arg)) return None;
+    unsigned num_users =
         std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
-    return std::make_pair(block_arg.getArgNumber(), num_users);
+    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
+    return block_arg_info;
   }
+  return None;
+}
 
-  if (auto pad_op = llvm::dyn_cast(conv2d_input.getDefiningOp())) {
-    auto pad_input = pad_op.input();
-    if (auto block_arg = pad_input.dyn_cast()) {
-      if (!Conv2DInputShapeCanTransform(pad_input)) return None;
-      int num_users = std::distance(block_arg.getUsers().begin(),
-                                    block_arg.getUsers().end());
-      return std::make_pair(block_arg.getArgNumber(), num_users);
+// Gets input block argument id and number of users for the input recursively.
+// Currently supported ops between convolution input and the block arguments
+// are PadOp and CastOp.
+Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
+  auto block_arg_num = GetBlockArgNum(input);
+  if (block_arg_num.hasValue()) return block_arg_num;
+
+  Value next_input = input;
+  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
+  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
+
+  while (pad_op || cast_op) {
+    if (pad_op) {
+      auto block_arg_num = GetBlockArgNum(pad_op.input());
+      if (block_arg_num.hasValue()) return block_arg_num;
+      next_input = pad_op.input();
+    } else {
+      auto block_arg_num = GetBlockArgNum(cast_op.x());
+      if (block_arg_num.hasValue()) return block_arg_num;
+      next_input = cast_op.x();
     }
+    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
+    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
   }
   return None;
 }
 
-// Apply space to depth transform for the first convolution on TPU device.
+// Checks if a convolution can apply SpaceToDepth transform.
+// Only the first convolution in the graph whose batch size is smaller than 8
+// and whose input feature size is smaller than 8 can be transformed.
+Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
+  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
+    return None;
+  }
+  // Currently supported ops between convolution input and the block arguments
+  // are PadOp and CastOp.
+  return GetInputBlockArgNum(conv2d.input());
+}
+
+// Applies space to depth transform for the first convolution on TPU device.
 void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   // Check if input and filter type are RankedTensorType.
   auto input_tensor_type =
@@ -563,8 +623,9 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   SmallVector new_filter_shape(filter_shape.begin(),
                                filter_shape.end());
 
-  // Rewrite Conv2DBackPropFilter after the first convolution.
-  for (Operation* user : conv2d.getOperation()->getUsers()) {
+  // Rewrite the Conv2DBackPropFilter that uses the first convolution's input.
+  if (!conv2d_input.getDefiningOp()) return;
+  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
     if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
       HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                  block_size);
@@ -572,7 +633,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
   }
 }
 
-// Get block size that is equal to stride from spatial dimension
+// Gets block size that is equal to stride from spatial dimension
 // from convolution.
 // Space to depth transform won't be triggered if block size <= 1.
 int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
@@ -617,13 +678,13 @@ void TPUSpaceToDepthPass::runOnOperation() {
 
   // Find out the qualified convolutions and its block argument ids.
   auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
-    Optional> arg_num_and_num_users =
+    Optional<BlockArgumentInfo> arg_num_and_num_users =
         GetConv2DInputArgNum(conv2d);
     if (arg_num_and_num_users.hasValue()) {
       // Get block size for the first convolution.
       int64_t block_size = GetConv2DBlockSize(conv2d);
-      auto arg_num = arg_num_and_num_users.getValue().first;
-      auto num_users = arg_num_and_num_users.getValue().second;
+      auto arg_num = arg_num_and_num_users.getValue().arg_num;
+      auto num_users = arg_num_and_num_users.getValue().num_users;
       argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
       argnum_num_users[arg_num] = num_users;
       return WalkResult::interrupt();