[MLIR:TPU] Improvements for space to depth pass.
PiperOrigin-RevId: 326468703 Change-Id: I9cd1a38cd652aacbe255d397071f6b82d6222194
commit b07e34b7b3 (parent dc3f225de9)
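For context, here is a minimal sketch of the host-side rewrite this pass performs, using the shapes exercised by the test below; the iterator value name is illustrative and the surrounding IR is elided:

%input:2 = "tf.IteratorGetNext"(%iterator) : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
// The pass inserts a host-side SpaceToDepth with block_size = 2 on the image input ...
%space_to_depth = "tf.SpaceToDepth"(%input#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
// ... and rewrites the first device convolution to consume it: strides become
// [1, 1, 1, 1] and the 7x7x3x64 filter is repacked to 4x4x12x64 (see the CHECK
// lines in the test below).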
@@ -83,5 +83,80 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
}
}

// -----

// Tests for space to depth host and device transform with replicate inputs.
module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSITE:0" = {}, "/job:localhost/replica:0/task:0/device:CPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:0" = {}, "/job:localhost/replica:0/task:0/device:TPU:1" = {}, "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0" = {}}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 458 : i32}} {
func @main(%arg0: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg1: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg3: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf.resource> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg5: tensor<!tf.variant> {tf._user_specified_name = "iterator", tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg7: tensor<*x!tf.resource<tensor<64x1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg8: tensor<*x!tf.resource<tensor<1001xf32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg9: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg10: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg11: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}, %arg12: tensor<*x!tf.resource<tensor<f32>>> {tf._composite_device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0", tf.device = "/job:localhost/replica:0/task:0/device:COMPOSITE:0"}) attributes {tf.entry_function = {control_outputs = "IteratorGetNext,IteratorGetNext_1,CrossReplicaSum,AssignAddVariableOp,CrossReplicaSum_1,AssignAddVariableOp_1,CrossReplicaSum_2,AssignAddVariableOp_2,CrossReplicaSum_3,AssignAddVariableOp_3", inputs = "iterator,iterator_1,iterator_2,iterator_3,iterator_4,iterator_5,resnet50_conv1_conv2d_conv1_kernel_140365606309224_handle_inputs_0,resnet50_fc1000_matmul_fc1000_kernel_140365944145960_handle_inputs_0,resnet50_fc1000_biasadd_fc1000_bias_140365944146240_handle_inputs_0,total_140366323758976_handle_inputs_0,count_140366323759312_handle_inputs_0,total_140366323760264_handle_inputs_0,count_140366323760600_handle_inputs_0", outputs = ""}} {
// CHECK: %[[INPUT00:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH00:.*]] = "tf.SpaceToDepth"([[INPUT00:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%0:2 = "tf.IteratorGetNext"(%arg2) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
// CHECK: %[[INPUT01:.*]] = "tf.IteratorGetNext"
// CHECK-DAG: %[[SPACETODEPTH01:.*]] = "tf.SpaceToDepth"([[INPUT01:.*]]#0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%1:2 = "tf.IteratorGetNext"(%arg4) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<2x224x224x3xf32>, tensor<2x1xf32>)
tf_device.replicate([%0#0, %1#0] as %arg13: tensor<2x224x224x3xf32>, [%0#1, %1#1] as %arg14: tensor<2x1xf32>, %arg6 as %arg15: tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>, %arg8 as %arg16: tensor<*x!tf.resource<tensor<1001xf32>>>, %arg7 as %arg17: tensor<*x!tf.resource<tensor<64x1001xf32>>>, %arg9 as %arg18: tensor<*x!tf.resource<tensor<f32>>>, %arg10 as %arg19: tensor<*x!tf.resource<tensor<f32>>>, %arg11 as %arg20: tensor<*x!tf.resource<tensor<f32>>>, %arg12 as %arg21: tensor<*x!tf.resource<tensor<f32>>>) {_mirrored_variable_indices = [2, 3, 4, 5, 6, 7, 8], _replicated_input_indices = [1, 2, -1, -1, -1, -1, -1, -1, -1], devices = {}, n = 2 : i32} {
%2 = "tf.ReadVariableOp"(%arg15) : (tensor<*x!tf.resource<tensor<7x7x3x64xf32>>>) -> tensor<7x7x3x64xf32>
%3 = "tf.ReadVariableOp"(%arg16) : (tensor<*x!tf.resource<tensor<1001xf32>>>) -> tensor<1001xf32>
%4 = "tf.ReadVariableOp"(%arg17) : (tensor<*x!tf.resource<tensor<64x1001xf32>>>) -> tensor<64x1001xf32>
%5 = "tf.ReadVariableOp"(%arg18) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%6 = "tf.ReadVariableOp"(%arg19) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%7 = "tf.ReadVariableOp"(%arg20) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%8 = "tf.ReadVariableOp"(%arg21) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%9:4 = "tf_device.cluster_func"(%arg13, %arg14, %2, %4, %3, %5, %6, %7, %8) {_tpu_replicate = "cluster_eval_step", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], func = @_func, host_compute_core = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"], padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : (tensor<2x224x224x3xf32>, tensor<2x1xf32>, tensor<7x7x3x64xf32>, tensor<64x1001xf32>, tensor<1001xf32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
"tf.AssignVariableOp"(%arg18, %9#0) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg19, %9#1) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg20, %9#2) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
"tf.AssignVariableOp"(%arg21, %9#3) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
tf_device.return
}
return
}

// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
|
||||
%4 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
%5 = "tf.Const"() {value = dense<2.500000e-01> : tensor<f32>} : () -> tensor<f32>
|
||||
%6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%7 = "tf.Const"() {value = dense<[-1, 1001]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%8 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%9 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%10 = "tf.Const"() {value = dense<[[0, 0], [3, 3], [3, 3], [0, 0]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32>
|
||||
%11 = "tf.Pad"(%arg0, %10) : (tensor<2x224x224x3xf32>, tensor<4x2xi32>) -> tensor<2x230x230x3xf32>
|
||||
%12 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2x1xf32>) -> tensor<2x1xi64>
|
||||
%13 = "tf.Reshape"(%12, %9) : (tensor<2x1xi64>, tensor<1xi32>) -> tensor<2xi64>
|
||||
%14 = "tf.Squeeze"(%arg1) {squeeze_dims = [-1]} : (tensor<2x1xf32>) -> tensor<2xf32>
|
||||
// CHECK: "tf.Conv2D"
|
||||
// CHECK-SAME: strides = [1, 1, 1, 1]
|
||||
// CHECK-SAME: (tensor<2x115x115x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
|
||||
%15 = "tf.Conv2D"(%11, %arg2) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
|
||||
%16 = "tf.Mean"(%15, %8) {keep_dims = false} : (tensor<2x112x112x64xf32>, tensor<2xi32>) -> tensor<2x64xf32>
|
||||
%17 = "tf.MatMul"(%16, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x64xf32>, tensor<64x1001xf32>) -> tensor<2x1001xf32>
|
||||
%18 = "tf.BiasAdd"(%17, %arg4) {data_format = "NHWC"} : (tensor<2x1001xf32>, tensor<1001xf32>) -> tensor<2x1001xf32>
|
||||
%19 = "tf.Reshape"(%18, %7) : (tensor<2x1001xf32>, tensor<2xi32>) -> tensor<2x1001xf32>
|
||||
%loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%19, %13) : (tensor<2x1001xf32>, tensor<2xi64>) -> (tensor<2xf32>, tensor<2x1001xf32>)
|
||||
%20 = "tf.Sum"(%loss, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
|
||||
%21 = "tf.Mul"(%20, %5) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%22 = "tf.Sum"(%21, %4) {keep_dims = false} : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
|
||||
%23 = "tf.CrossReplicaSum"(%22, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
|
||||
%24 = "tf.Softmax"(%18) : (tensor<2x1001xf32>) -> tensor<2x1001xf32>
|
||||
%25 = "tf.ArgMax"(%24, %2) : (tensor<2x1001xf32>, tensor<i32>) -> tensor<2xi64>
|
||||
%26 = "tf.Cast"(%25) {Truncate = false} : (tensor<2xi64>) -> tensor<2xf32>
|
||||
%27 = "tf.Equal"(%14, %26) {incompatible_shape_error = true} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
|
||||
%28 = "tf.Cast"(%27) {Truncate = false} : (tensor<2xi1>) -> tensor<2xf32>
|
||||
%29 = "tf.Sum"(%28, %6) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
|
||||
%30 = "tf.CrossReplicaSum"(%29, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
|
||||
%31 = "tf.AddV2"(%arg5, %23) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%32 = "tf.CrossReplicaSum"(%1, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
|
||||
%33 = "tf.AddV2"(%arg6, %32) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%34 = "tf.AddV2"(%arg7, %30) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%35 = "tf.CrossReplicaSum"(%0, %3) : (tensor<f32>, tensor<1x2xi32>) -> tensor<f32>
|
||||
%36 = "tf.AddV2"(%arg8, %35) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %31, %33, %34, %36 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
|
||||
}
}
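Note: the hunk above does not show the test file's RUN line. Like other MLIR lit tests, the CHECK / CHECK-DAG / CHECK-SAME directives are verified by running the pass through tf-opt and piping the result into FileCheck, along the lines of the following (the exact pass flag is assumed here, not shown in this diff):

// RUN: tf-opt %s -split-input-file -tf-tpu-space-to-depth-pass | FileCheck %s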
@@ -54,6 +54,11 @@ namespace {
constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
unsigned arg_num;
unsigned num_users;
};

// A pass that applies automatic space to depth transform for the first or
// frontier convolutions consume host inputs on TPU.
// This is done by adding space to depth transform op after host input and
@@ -108,7 +113,49 @@ struct TPUSpaceToDepthPass
void runOnOperation() override;
};

// Handle padding before convolution for space to depth transform.
// Updates func argument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
auto arg_types = llvm::to_vector<8>(func.front().getArgumentTypes());
auto result_types =
llvm::to_vector<4>(func.front().getTerminator()->getOperandTypes());
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}

void HandleFuncOp(Operation* op) {
auto func = llvm::cast<FuncOp>(op);
UpdateFuncType(func);
}

// Handles cast op between the first convolution and the block argument.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
auto cast_input = cast_op.x();
// Update input type.
auto transform_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
cast_input.setType(transform_result_type);
auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
while (block_arg || cast_op_input) {
if (block_arg) {
// Change on device function type/shape.
HandleFuncOp(block_arg.getOwner()->getParentOp());
block_arg = nullptr;
cast_op_input = nullptr;
} else {
auto cast_input = cast_op_input.x();
// Update input type.
auto transform_result_type =
RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
cast_input.setType(transform_result_type);
// Update block arg and cast_op_input.
block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
}
}
return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return failure();
@@ -134,6 +181,10 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
pad_input_shape[0], pad_input_shape[1] / block_size,
pad_input_shape[2] / block_size,
pad_input_shape[3] * block_size * block_size};
// Input of the pad op could be a cast op.
if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
if (failed(HandleCast(cast_op, transform_shape))) return failure();

auto transform_result_type =
RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
input.setType(transform_result_type);
@@ -141,7 +192,7 @@ LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
return success();
}

// Handle stride for the first convolution for the transform.
// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
MLIRContext* context = conv2d.getContext();
SmallVector<int64_t, 4> values = {1, 1, 1, 1};
@@ -153,7 +204,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
conv2d.setAttr("strides", strides);
}

// Transform input shape for the first convolution.
// Transforms input shape for the first convolution.
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
auto input = conv2d.input();
auto input_shape = input.getType().cast<RankedTensorType>().getShape();
@@ -165,7 +216,7 @@ void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
input.setType(transform_result_type);
}

// Add padding for convolution filter for space to depth transform.
// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
OpBuilder* builder, int32_t pad_h,
int32_t pad_w) {
@@ -185,7 +236,7 @@ TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
paddings_value);
}

// Create reshape op for space to depth transform.
// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
Value input, OpBuilder* builder) {
auto reshape_result_type =
@@ -199,7 +250,7 @@ TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
input, reshape_value);
}

// Create transpose op for shape to depth transform.
// Creates transpose op for shape to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
@@ -259,7 +310,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
conv2d.setOperand(1, final_reshape_op);
}

// Create slice op for filter in back prop pass.
// Creates slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
@@ -281,7 +332,7 @@ TF::SliceOp GetSliceOpForConv2DBackPropFilter(
start_position, slice_size_op);
}

// Transform Conv2DBackPropFilter for space to depth.
// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
ArrayRef<int32_t> old_filter_shape,
ArrayRef<int32_t> new_filter_shape,
@@ -354,22 +405,6 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
backprop.replaceAllUsesWith(slice_op.getResult());
}

// Update func arugument type to have the updated input shape.
void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
arg_types.reserve(func.getNumArguments());
for (auto arg : func.getArguments()) arg_types.emplace_back(arg.getType());
auto terminator = func.front().getTerminator();
SmallVector<Type, 4> result_types(terminator->operand_type_begin(),
terminator->operand_type_end());
func.setType(FunctionType::get(arg_types, result_types, func.getContext()));
}

void HandleFuncOp(Operation* op) {
auto func = llvm::cast<FuncOp>(op);
UpdateFuncType(func);
}

// Checks if the input producer op is supported in this transform. Right now, we
// only check if it is a host tf.IteratorGetNext.
bool IsSupportedHostInputOp(Operation* op) {
@@ -417,9 +452,10 @@ TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
tf_device::ClusterFuncOp cluster_func,
int64_t replicate_arg_index,
BlockArgument block_arg,
tf_device::ReplicateOp replicate,
int32_t block_size) {
int64_t replicate_arg_index = block_arg.getArgNumber();
// We need to know the devices to copy to.
if (!replicate.devices()) return false;
int64_t num_replicas = replicate.n().getZExtValue();
@@ -439,6 +475,7 @@ bool HandleHostReplicatedInputs(int64_t index,
BuildSpaceToDepth(cluster_func, entry.value(), block_size, input_shape);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
space_to_depth);
block_arg.setType(space_to_depth.getType());
}
return true;
}
@@ -457,9 +494,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
// For a block argument, consider transforms only when it is a replicated
// input (defining ops will be outside the replicate node).
if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
HandleHostReplicatedInputs(input.index(), cluster_func,
block_arg.getArgNumber(), maybe_replicate,
block_size);
HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
maybe_replicate, block_size);
}
} else {
// For an op output, consider transforms only when 1) there is no
@@ -482,7 +518,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
}
}

// Check if input shape of convolution is good for space to depth transform.
// Checks if input shape of convolution is good for space to depth transform.
bool Conv2DInputShapeCanTransform(Value input) {
auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
if (!ranked_type) return false;
@@ -495,35 +531,59 @@ bool Conv2DInputShapeCanTransform(Value input) {
return true;
}

// Checks if a convoluton can apply SpaceToDepth transform.
// Only the first convolution in the graph whose batch size smaller than 8
// and its input feature size smaller than 8 can be transformed.
Optional<std::pair<unsigned, int>> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
return None;
}
auto conv2d_input = conv2d.input();
if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(conv2d_input)) return None;
int num_users =
// Get block argument id and number of users for the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(arg)) return None;
unsigned num_users =
std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
return std::make_pair(block_arg.getArgNumber(), num_users);
BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
return block_arg_info;
}
return None;
}

if (auto pad_op = llvm::dyn_cast<TF::PadOp>(conv2d_input.getDefiningOp())) {
auto pad_input = pad_op.input();
if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
if (!Conv2DInputShapeCanTransform(pad_input)) return None;
int num_users = std::distance(block_arg.getUsers().begin(),
block_arg.getUsers().end());
return std::make_pair(block_arg.getArgNumber(), num_users);
// Gets input block argument id and number of users for the input recursively.
// Current supported ops between convolution input and the block arguments are
// PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
auto block_arg_num = GetBlockArgNum(input);
if (block_arg_num.hasValue()) return block_arg_num;

Value next_input = input;
auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

while (pad_op || cast_op) {
if (pad_op) {
auto block_arg_num = GetBlockArgNum(pad_op.input());
if (block_arg_num.hasValue()) return block_arg_num;
next_input = pad_op.input();
} else {
auto block_arg_num = GetBlockArgNum(cast_op.x());
if (block_arg_num.hasValue()) return block_arg_num;
next_input = cast_op.x();
}
pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
}

return None;
}

// Apply space to depth transform for the first convolution on TPU device.
// Checks if a convoluton can apply SpaceToDepth transform.
// Only the first convolution in the graph whose batch size smaller than 8
// and its input feature size smaller than 8 can be transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
return None;
}
// Current supported ops between convolution input and the block arguments are
// PadOp and CastOp.
return GetInputBlockArgNum(conv2d.input());
}

// Applies space to depth transform for the first convolution on TPU device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
// Check if input and filter type are RankedTensorType.
auto input_tensor_type =
@@ -563,8 +623,9 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
filter_shape.end());

// Rewrite Conv2DBackPropFilter after the first convolution.
for (Operation* user : conv2d.getOperation()->getUsers()) {
// Rewrite Conv2DBackPropFilter that is the user of first convolution's input.
if (!conv2d_input.getDefiningOp()) return;
for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
block_size);
@@ -572,7 +633,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
}
}

// Get block size that is equal to stride from spatial dimension
// Gets block size that is equal to stride from spatial dimension
// from convolution.
// Space to depth transform won't be triggered if block size <= 1.
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
@@ -617,13 +678,13 @@ void TPUSpaceToDepthPass::runOnOperation() {

// Find out the qualified convolutions and its block argument ids.
auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
Optional<std::pair<unsigned, int>> arg_num_and_num_users =
Optional<BlockArgumentInfo> arg_num_and_num_users =
GetConv2DInputArgNum(conv2d);
if (arg_num_and_num_users.hasValue()) {
// Get block size for the first convolution.
int64_t block_size = GetConv2DBlockSize(conv2d);
auto arg_num = arg_num_and_num_users.getValue().first;
auto num_users = arg_num_and_num_users.getValue().second;
auto arg_num = arg_num_and_num_users.getValue().arg_num;
auto num_users = arg_num_and_num_users.getValue().num_users;
argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
argnum_num_users[arg_num] = num_users;
return WalkResult::interrupt();