Fix XLA HLO op import mapping of output_feature_dimension
PiperOrigin-RevId: 309872901 Change-Id: I4a9ff9af41e43408f437652abbd7c8ae071f4700
This commit is contained in:
parent
2f27c4b3fd
commit
6059623a34
@ -117,7 +117,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
|
||||
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
Convert(kernel_spatial_dims, builder),
|
||||
builder->getI64IntegerAttr(dnums.output_batch_dimension()),
|
||||
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
builder->getI64IntegerAttr(dnums.output_feature_dimension()),
|
||||
Convert(output_spatial_dims, builder), builder->getContext());
|
||||
}
|
||||
|
||||
|
@ -244,8 +244,8 @@ add {
|
||||
// CHECK-SAME: kernel_input_feature_dimension = 2 : i64
|
||||
// CHECK-SAME: kernel_output_feature_dimension = 3 : i64
|
||||
// CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||
// CHECK-SAME: output_batch_dimension = 0 : i64
|
||||
// CHECK-SAME: output_feature_dimension = 3 : i64
|
||||
// CHECK-SAME: output_batch_dimension = 3 : i64
|
||||
// CHECK-SAME: output_feature_dimension = 0 : i64
|
||||
// CHECK-SAME: output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// CHECK-SAME: }
|
||||
// CHECK-SAME: feature_group_count = 1 : i64
|
||||
@ -255,11 +255,11 @@ add {
|
||||
// CHECK-SAME: rhs_dilations = dense<[2, 3]> : tensor<2xi64>
|
||||
// CHECK-SAME: window_strides = dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK-SAME: }
|
||||
// CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32>
|
||||
|
||||
%convolution.4 = f32[256,30,30,16]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
|
||||
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
|
||||
|
||||
// CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
|
Loading…
Reference in New Issue
Block a user