[MLIR:TPU] Improvements for space to depth pass.

PiperOrigin-RevId: 326468703
Change-Id: I9cd1a38cd652aacbe255d397071f6b82d6222194
A. Unique TensorFlower 2020-08-13 10:02:39 -07:00 committed by TensorFlower Gardener
parent dc3f225de9
commit b07e34b7b3
2 changed files with 192 additions and 56 deletions

View File

@@ -83,5 +83,80 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
}
}
// -----
// Tests for space to depth host and device transform with replicate inputs.
module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
func @main(%arg0: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
// CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
// CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {}, n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
%3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource<tensor<1001xf32>>>) -> tensor<1001xf32>
%4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource<tensor<64x1001xf32>>>) -> tensor<64x1001xf32>
%5 = "tf.ReadVariableOp"(%arg18) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%6 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%7 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%8 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%9:4 = "tf_device.cluster_func"(%arg13, %arg14, %2, %4, %3, %5, %6, %7, %8) {_tpu_replicate = "cluster_eval_step", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<2x1xf32>, tensor<7x7x3x64xf32>, tensor<64x1001xf32>, tensor<1001xf32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
"tf.AssignVariableOp"(%arg18, %9#0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg19, %9#1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg20, %9#2) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg21, %9#3) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
%4 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
%5 = "tf.Const"() {value = dense<2.500000e-01> : tensor<f32>} : () -> tensor<f32>
%6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
%7 = "tf.Const"() {value = dense<[-1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32>
%8 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%9 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%10 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
%11 = "tf.Pad"(%arg0, %10) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32>
%12 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x1xf32>) -> tensor<2x1xi64>
%13 = "tf.Reshape"(%12, %9) : (tensor<2x1xi64>, tensor<1xi32>) -> tensor<2xi64>
%14 = "tf.Squeeze"(%arg1) {squeeze_dims = [-1]} : (tensor<2x1xf32>) -> tensor<2xf32>
// CHECK: "tf.Conv2D"
// CHECK-SAME: strides = [1, 1, 1, 1]
// CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
%15 = "tf.Conv2D"(%11, %arg2) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
%16 = "tf.Mean"(%15, %8) {keep_dims = false} : (tensor<2x112x112x64xf32>, tensor<2xi32>) -> tensor<2x64xf32>
%17 = "tf.MatMul"(%16, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x64xf32>, tensor<64x1001xf32>) -> tensor<2x1001xf32>
%18 = "tf.BiasAdd"(%17, %arg4) {data_format = "NHWC"} : (tensor<2x1001xf32>, tensor<1001xf32>) -> tensor<2x1001xf32>
%19 = "tf.Reshape"(%18, %7) : (tensor<2x1001xf32>, tensor<2xi32>) -> tensor<2x1001xf32>
%loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%19, %13) : (tensor<2x1001xf32>, tensor<2xi64>) -> (tensor<2xf32>, tensor<2x1001xf32>)
%20 = "tf.Sum"(%loss, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%21 = "tf.Mul"(%20, %5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%22 = "tf.Sum"(%21, %4) {keep_dims = false} : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
%23 = "tf.CrossReplicaSum"(%22, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%24 = "tf.Softmax"(%18) : (tensor<2x1001xf32>) -> tensor<2x1001xf32>
%25 = "tf.ArgMax"(%24, %2) : (tensor<2x1001xf32>, tensor<i32>) -> tensor<2xi64>
%26 = "tf.Cast"(%25) {Truncate = false} : (tensor<2xi64>) -> tensor<2xf32>
%27 = "tf.Equal"(%14, %26) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
%28 = "tf.Cast"(%27) {Truncate = false} : (tensor<2xi1>) -> tensor<2xf32>
%29 = "tf.Sum"(%28, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
%30 = "tf.CrossReplicaSum"(%29, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%31 = "tf.AddV2"(%arg5, %23) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%32 = "tf.CrossReplicaSum"(%1, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%33 = "tf.AddV2"(%arg6, %32) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%34 = "tf.AddV2"(%arg7, %30) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%35 = "tf.CrossReplicaSum"(%0, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
%36 = "tf.AddV2"(%arg8, %35) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %31, %33, %34, %36 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}
}
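Note on the shapes the CHECK lines above expect: space to depth with block size b maps an NHWC activation (N, H, W, C) to (N, H/b, W/b, C*b*b). The short standalone C++ sketch below (not part of the pass; the helper names are invented for illustration) reproduces the numbers in this test for block size 2.

#include <array>
#include <cstdint>
#include <iostream>

// Space to depth on an NHWC activation: each b x b spatial tile is folded
// into the channel dimension.
std::array<int64_t, 4> SpaceToDepthShape(const std::array<int64_t, 4>& nhwc,
                                         int64_t b) {
  return {nhwc[0], nhwc[1] / b, nhwc[2] / b, nhwc[3] * b * b};
}

void Print(const std::array<int64_t, 4>& s) {
  std::cout << s[0] << "x" << s[1] << "x" << s[2] << "x" << s[3] << "\n";
}

int main() {
  // Host input from tf.IteratorGetNext: 2x224x224x3 -> 2x112x112x12,
  // matching the tf.SpaceToDepth result type in the CHECK-DAG lines.
  Print(SpaceToDepthShape({2, 224, 224, 3}, 2));
  // Padded device input: 2x230x230x3 -> 2x115x115x12, which is what the
  // rewritten stride-[1, 1, 1, 1] tf.Conv2D consumes to produce 2x112x112x64.
  Print(SpaceToDepthShape({2, 230, 230, 3}, 2));
  return 0;
}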

View File

@@ -54,6 +54,11 @@ namespace {
constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;
struct BlockArgumentInfo {
unsigned arg_num;
unsigned num_users;
};
// A pass that applies automatic space to depth transform for the first or
// frontier convolutions consume host inputs on TPU.
// This is done by adding space to depth transform op after host input and
@@ -108,7 +113,49 @@ struct TPUSpaceToDepthPass
void runOnOperation() override;
};
// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes());
auto result_types =
llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes());
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}
void HandleFuncOp(Operation* op) {
auto func = llvm::cast<FuncOp>(op);
UpdateFuncType(func);
}
// Handles cast op between the first convolution and the block argument.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
auto cast_input = cast_op.x();
// Update input type.
auto transform_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
cast_input.setType(transform_result_type);
auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
while (block_arg || cast_op_input) {
if (block_arg) {
// Change on device function type/shape.
HandleFuncOp(block_arg.getOwner()->getParentOp());
block_arg = nullptr;
cast_op_input = nullptr;
} else {
auto cast_input = cast_op_input.x();
// Update input type.
auto transform_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
cast_input.setType(transform_result_type);
// Update block arg and cast_op_input.
block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
}
}
return success();
}
// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return failure();
@@ -134,6 +181,10 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
pad_input_shape[0], pad_input_shape[1] / block_size,
pad_input_shape[2] / block_size,
pad_input_shape[3] * block_size * block_size};
// Input of the pad op could be a cast op.
if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
if (failed(HandleCast(cast_op, transform_shape))) return failure();
auto transform_result_type =
RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
input.setType(transform_result_type);
@@ -141,7 +192,7 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
return success();
}
// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
MLIRContext* context = conv2d.getContext();
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
@@ -153,7 +204,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
conv2d.setAttr("strides", strides);
}
// Transforms input shape for the first convolution.
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
auto input = conv2d.input();
auto input_shape = input.getType().cast<RankedTensorType>().getShape();
@@ -165,7 +216,7 @@ void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
input.setType(transform_result_type);
}
// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
OpBuilder* builder, int32_t pad_h,
int32_t pad_w) {
@@ -185,7 +236,7 @@ TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
paddings_value);
}
// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
Value input, OpBuilder* builder) {
auto reshape_result_type =
@@ -199,7 +250,7 @@ TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
input, reshape_value);
}
// Creates transpose op for shape to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
@@ -259,7 +310,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
conv2d.setOperand(1, final_reshape_op);
}
// Creates slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
@@ -281,7 +332,7 @@ TF::SliceOp GetSliceOpForConv2DBackPropFilter(
start_position, slice_size_op);
}
// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
ArrayRef<int32_t> old_filter_shape,
ArrayRef<int32_t> new_filter_shape,
@@ -354,22 +405,6 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
backprop.replaceAllUsesWith(slice_op.getResult());
}
// Update func arugument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
arg_types.reserve(func.getNumArguments());
for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType());
auto terminator = func.front().getTerminator();
SmallVector<Type, 4> result_types(terminator->operand_type_begin(),
terminator->operand_type_end());
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}
void HandleFuncOp(Operation* op) {
auto func = llvm::cast<FuncOp>(op);
UpdateFuncType(func);
}
// Checks if the input producer op is supported in this transform. Right now, we
// only check if it is a host tf.IteratorGetNext.
bool IsSupportedHostInputOp(Operation* op) {
@@ -417,9 +452,10 @@ TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
tf_device::ClusterFuncOp cluster_func,
BlockArgument block_arg,
tf_device::ReplicateOp replicate,
int32_t block_size) {
int64_t replicate_arg_index = block_arg.getArgNumber();
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
@@ -439,6 +475,7 @@ bool HandleHostReplicatedInputs(int64_t index,
BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
space_to_depth);
block_arg.setType(space_to_depth.getType());
}
return true;
}
@@ -457,9 +494,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
// For a block argument, consider transforms only when it is a replicated
// input (defining ops will be outside the replicate node).
if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
maybe_replicate, block_size);
}
} else {
// For an op output, consider transforms only when 1) there is no
@@ -482,7 +518,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
}
}
// Checks if input shape of convolution is good for space to depth transform.
bool Conv2DInputShapeCanTransform(Value input) {
auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return false;
@@ -495,35 +531,59 @@ bool Conv2DInputShapeCanTransform(Value input) {
return true;
}
// Get block argument id and number of users for the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(arg)) return None;
unsigned num_users =
std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
return block_arg_info;
}
return None;
}
// Gets input block argument id and number of users for the input recursively.
// Current supported ops between convolution input and the block arguments are
// PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
auto block_arg_num = GetBlockArgNum(input);
if (block_arg_num.hasValue()) return block_arg_num;
Value next_input = input;
auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
while (pad_op || cast_op) {
if (pad_op) {
auto block_arg_num = GetBlockArgNum(pad_op.input());
if (block_arg_num.hasValue()) return block_arg_num;
next_input = pad_op.input();
} else {
auto block_arg_num = GetBlockArgNum(cast_op.x());
if (block_arg_num.hasValue()) return block_arg_num;
next_input = cast_op.x();
}
pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
}
return None;
}
// Checks if a convoluton can apply SpaceToDepth transform.
// Only the first convolution in the graph whose batch size smaller than 8
// and its input feature size smaller than 8 can be transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
return None;
}
// Current supported ops between convolution input and the block arguments are
// PadOp and CastOp.
return GetInputBlockArgNum(conv2d.input());
}
// Applies space to depth transform for the first convolution on TPU device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
// Check if input and filter type are RankedTensorType.
auto input_tensor_type =
@@ -563,8 +623,9 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
filter_shape.end());
// Rewrite Conv2DBackPropFilter that is the user of first convolution's input.
if (!conv2d_input.getDefiningOp()) return;
for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
block_size);
@@ -572,7 +633,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
}
}
// Gets block size that is equal to stride from spatial dimension
// from convolution.
// Space to depth transform won't be triggered if block size <= 1.
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
@@ -617,13 +678,13 @@ void TPUSpaceToDepthPass::runOnOperation() {
// Find out the qualified convolutions and its block argument ids.
auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
Optional<BlockArgumentInfo> arg_num_and_num_users =
GetConv2DInputArgNum(conv2d);
if (arg_num_and_num_users.hasValue()) {
// Get block size for the first convolution.
int64_t block_size = GetConv2DBlockSize(conv2d);
auto arg_num = arg_num_and_num_users.getValue().arg_num;
auto num_users = arg_num_and_num_users.getValue().num_users;
argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
argnum_num_users[arg_num] = num_users;
return WalkResult::interrupt();
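For the filter side, the first convolution's 7x7x3x64 kernel ends up as the 4x4x12x64 filter seen in the test above: it is padded to 8x8 spatially, split so each spatial dimension carries its block factor, transposed with the {0, 2, 1, 3, 4, 5} permutation used by GetTransposeOpForConv2DFilter, and the block factors are folded into the input channels. The standalone C++ sketch below walks through that shape bookkeeping only; the intermediate grouping is an assumption for illustration, not code from the pass.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t b = 2;                          // block size, equal to the conv stride
  std::array<int64_t, 4> hwio = {7, 7, 3, 64};  // original first-conv filter
  // Pad spatial dims up to a multiple of the block size: 8x8x3x64.
  std::array<int64_t, 4> padded = {8, 8, hwio[2], hwio[3]};
  // Expose the block factors (assumed grouping of H and W): 4x2x4x2x3x64.
  std::array<int64_t, 6> split = {padded[0] / b, b, padded[1] / b, b,
                                  padded[2], padded[3]};
  // Permute with {0, 2, 1, 3, 4, 5}: 4x4x2x2x3x64.
  std::array<int64_t, 6> perm = {split[0], split[2], split[1], split[3],
                                 split[4], split[5]};
  // Collapse block factors into input channels: 4x4x(2*2*3)x64 = 4x4x12x64.
  std::array<int64_t, 4> result = {perm[0], perm[1],
                                   perm[2] * perm[3] * perm[4], perm[5]};
  std::cout << result[0] << "x" << result[1] << "x" << result[2] << "x"
            << result[3] << "\n";  // prints 4x4x12x64
  return 0;
}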