From 6059623a34cff974b28ae2865b25ec378bd7d615 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 4 May 2020 20:40:46 -0700 Subject: [PATCH] Fix XLA HLO op import mapping of output_feature_dimension PiperOrigin-RevId: 309872901 Change-Id: I4a9ff9af41e43408f437652abbd7c8ae071f4700 --- tensorflow/compiler/mlir/xla/attribute_importer.cc | 2 +- .../compiler/mlir/xla/tests/translate/import.hlotxt | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc index 2d17127b075..201ec0d053f 100644 --- a/tensorflow/compiler/mlir/xla/attribute_importer.cc +++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc @@ -117,7 +117,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers( builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), Convert(kernel_spatial_dims, builder), builder->getI64IntegerAttr(dnums.output_batch_dimension()), - builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()), + builder->getI64IntegerAttr(dnums.output_feature_dimension()), Convert(output_spatial_dims, builder), builder->getContext()); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index d1133057544..75471e3a090 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -244,8 +244,8 @@ add { // CHECK-SAME: kernel_input_feature_dimension = 2 : i64 // CHECK-SAME: kernel_output_feature_dimension = 3 : i64 // CHECK-SAME: kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64> - // CHECK-SAME: output_batch_dimension = 0 : i64 - // CHECK-SAME: output_feature_dimension = 3 : i64 + // CHECK-SAME: output_batch_dimension = 3 : i64 + // CHECK-SAME: output_feature_dimension = 0 : i64 // CHECK-SAME: output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> // CHECK-SAME: } // CHECK-SAME: feature_group_count = 1 : i64 @@ -255,11 +255,11 @@ add { // CHECK-SAME: rhs_dilations = dense<[2, 3]> : tensor<2xi64> // CHECK-SAME: window_strides = dense<[4, 5]> : tensor<2xi64> // CHECK-SAME: } - // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<256x30x30x16xf32> + // CHECK-SAME: (tensor<256x32x32x6xf32>, tensor<2x2x1x1xf32>) -> tensor<16x30x30x256xf32> - %convolution.4 = f32[256,30,30,16]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} // CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple>