diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 73cf4009e27..4a27e74ad70 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -5,19 +5,19 @@ func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32> return %0 : tensor<2x3x4x6xf32> // CHECK-LABEL: batchMatMulV2TwoDim - // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> - // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> - // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> - // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} + // CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> @@ -67,16 +67,16 @@ func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) return %0 : tensor<3x4x6xf32> // CHECK-LABEL: batchMatMulV2FlatInput - // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> - // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> - // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> - // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} + // CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>} // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> @@ -122,19 +122,19 @@ func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) return %0 : tensor<2x3x4x6xf32> // CHECK-LABEL: batchMatMulTwoDim - // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> - // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> - // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> - // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} + // CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>} // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> @@ -184,16 +184,16 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) - return %0 : tensor<3x4x6xf32> // CHECK-LABEL: batchMatMulFlatInput - // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> - // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> - // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> - // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> - // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> - // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> - // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} + // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} + // CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>} // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index 7e439dcc99b..912a6aa722f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -71,7 +71,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( Type resultType = RankedTensorType::get(shape, element_type); auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape); auto shape_tensor = - rewriter.create(loc, shape_spec_type, constant_attr); + rewriter.create(loc, shape_spec_type, constant_attr); return rewriter.create(loc, resultType, /*tensor=*/value, /*shape=*/shape_tensor); } @@ -104,8 +104,8 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( auto begin_attr = DenseElementsAttr::get(vector3_type, {batch_idx, 0, 0}); auto size_attr = DenseElementsAttr::get(vector3_type, slice_size); - auto begin = rewriter.create(loc, vector3_type, begin_attr); - auto size = rewriter.create(loc, vector3_type, size_attr); + auto begin = rewriter.create(loc, vector3_type, begin_attr); + auto size = rewriter.create(loc, vector3_type, size_attr); auto slice_op = rewriter.create(loc, slice_result_type, /*input=*/reshape_op.output(), begin, size);