From 173dd7b6d5d00211ee80d6edff05febe9b4ea3ec Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 16 Dec 2020 20:29:15 -0800 Subject: [PATCH] Integrate LLVM at llvm/llvm-project@0cf7e4b252fe Updates LLVM usage to match [0cf7e4b252fe](https://github.com/llvm/llvm-project/commit/0cf7e4b252fe) PiperOrigin-RevId: 347948887 Change-Id: I87697e08bcfc29dbf259e523eb481366b4795537 --- tensorflow/compiler/mlir/hlo/BUILD | 5 +- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 7 +- .../mhlo/transforms/legalize_control_flow.cc | 10 +- .../transforms/mhlo_control_flow_to_scf.cc | 3 +- .../mlir/hlo/tests/hlo-legalize-to-lhlo.mlir | 14 +- .../mlir/hlo/tests/legalize-control-flow.mlir | 10 +- .../mlir/hlo/tests/legalize_to_scf.mlir | 6 +- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../tests/functional-control-flow-to-cfg.mlir | 16 +- .../functional_control_flow_to_cfg.cc | 3 +- .../tools/kernel_gen/tests/bufferize.mlir | 10 +- .../mlir/tools/kernel_gen/transforms/BUILD | 2 + .../kernel_gen/transforms/bufferize_pass.cc | 12 +- .../transforms/shape_to_descriptors_pass.cc | 2 + tensorflow/workspace.bzl | 4 +- third_party/mlir/BUILD | 372 +++++++++++++++++- third_party/mlir/test.BUILD | 12 + 17 files changed, 435 insertions(+), 54 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 219d391ab43..455f6070ca7 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -570,7 +570,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:TensorDialect", ], ) @@ -740,6 +740,7 @@ cc_library( "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOpsTransforms", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, @@ -809,11 +810,11 @@ cc_library( deps = [ ":hlo", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 11f91598b77..c06dfab0211 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -37,6 +38,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Bufferize.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project namespace mlir { namespace mhlo { @@ -62,7 +64,7 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, if (shape_element.value() != ShapedType::kDynamicSize) continue; Value index = rewriter->create(loc, shape_element.index()); Value alloc_operand = - rewriter->create(loc, shape_operand, index); + rewriter->create(loc, shape_operand, index); if (!alloc_operand.getType().isIndex()) { alloc_operand = rewriter->create(loc, alloc_operand, rewriter->getIndexType()); @@ -292,7 +294,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter for (int i = 0; i < result_rank; ++i) { Value i_val = b->create(loc, i); Value result_dim_size = - b->create(loc, op.output_dimensions(), i_val); + b->create(loc, op.output_dimensions(), i_val); if (!result_dim_size.getType().isIndex()) { result_dim_size = b->create(loc, result_dim_size, b->getIndexType()); @@ -567,6 +569,7 @@ struct HloLegalizeToLhlo ConversionTarget target(context); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalDialect(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 53472cba7c2..122000270fd 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project namespace mlir { namespace mhlo { @@ -83,7 +84,7 @@ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) { // Extract the predicate for checking branching, then branch to the true and // false regions appropriately. - auto cond_value = builder.create(loc, if_op.pred()); + auto cond_value = builder.create(loc, if_op.pred()); builder.create(loc, cond_value, true_block, if_op.true_arg(), false_block, if_op.false_arg()); @@ -142,7 +143,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { builder.create(loc, cond_block, while_op.getOperand()); // Updates the inlined condition blocks by replacing the return op with an - // extract_element and conditional branch. This changes the block below: + // tensor.extract and conditional branch. This changes the block below: // ^cond(%0): // // "mhlo".return(%1) @@ -150,7 +151,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { // Into: // ^cond(%0): // - // %2 = extract_element %1[] : tensor // Extract the condition value. + // %2 = tensor.extract %1[] : tensor // Extract the condition value. // cond_br %2, ^body(%0), ^tail(%0) // Branch. builder.setInsertionPointToStart(cond_block); @@ -166,7 +167,8 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { builder.setInsertionPointToEnd(new_block); auto return_value = return_op.getOperand(0); - auto cond_value = builder.create(loc, return_value); + auto cond_value = + builder.create(loc, return_value); // Get the body block arguments. llvm::SmallVector successor_args(cond_block->args_begin(), diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc index b79624dae4d..d6379883639 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #define DEBUG_TYPE "mhlo-control-flow-to-scf" @@ -119,7 +120,7 @@ void MatchAndRewrite(WhileOp whileOp) { auto tensorIndexType = RankedTensorType::get({}, b.getIndexType()); auto getAsIndex = [&](Value val) { auto loc = whileOp.getLoc(); - return b.create( + return b.create( loc, b.create(loc, tensorIndexType, val), ValueRange()); }; diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 5c05d5e946d..8cfffb37f15 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -177,16 +177,16 @@ func @dyn_broadcast(%operand: memref) -> index { // CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index // CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref -// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> +// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64> // CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index -// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> +// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64> // CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index // CHECK: %[[C2:.*]] = constant 2 : index -// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> +// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64> // CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index // CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index @@ -554,9 +554,9 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> + // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> + // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64> // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () @@ -577,9 +577,9 @@ func @tanh_dyn(%arg0: tensor) { // CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64> - // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> + // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index - // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> + // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xi64> // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir index 274792e62a2..8e5e18af7d4 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-control-flow.mlir @@ -5,7 +5,7 @@ func @while(%arg0: tensor) -> tensor { //CHECK: br ^bb1(%arg0 : tensor) //CHECK: ^bb1([[VAL0:%.+]]: tensor): //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]]) - //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor + //CHECK: [[VAL2:%.+]] = tensor.extract [[VAL1]][] : tensor //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor), ^bb3([[VAL0]] : tensor) //CHECK: ^bb2([[VAL3:%.+]]: tensor): //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]] @@ -33,7 +33,7 @@ func @conditional(%arg0: tensor) -> tensor { // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor + // CHECK: [[VAL1:%.+]] = tensor.extract [[VAL0]][] : tensor // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor), ^bb2(%arg0 : tensor) %1 = "mhlo.if"(%0, %arg0, %arg0) ( { @@ -63,7 +63,7 @@ func @while_with_multiple_blocks_in_body(%arg0: tensor) -> tensor { // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor) // CHECK: ^[[COND_ENTRY]](%0: tensor): // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %2 = extract_element %1[] : tensor + // CHECK: %2 = tensor.extract %1[] : tensor // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%3: tensor): // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor) @@ -95,7 +95,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor) // CHECK: ^[[COND_SUCC]](%1: tensor): // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor - // CHECK: %3 = extract_element %2[] : tensor + // CHECK: %3 = tensor.extract %2[] : tensor // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor), ^[[EXIT:.+]](%0 : tensor) // CHECK: ^[[BODY_ENTRY]](%4: tensor): // CHECK: br ^[[COND_ENTRY]](%4 : tensor) @@ -118,7 +118,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor) -> tensor { // CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { func @conditional_with_multiple_blocks(%arg0: tensor, %arg1: tensor, %pred: tensor) -> tensor { - // CHECK: %0 = extract_element %arg2[] : tensor + // CHECK: %0 = tensor.extract %arg2[] : tensor // CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor), ^[[ELSE_ENTRY:.+]](%arg1 : tensor) // CHECK: ^[[THEN_ENTRY]](%1: tensor): // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor) diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir index 9c887a73a0f..101800d617d 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize_to_scf.mlir @@ -30,9 +30,9 @@ func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg // CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor // CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor // CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor to tensor -// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor +// CHECK: %[[VAL_15:.*]] = tensor.extract %[[VAL_14]][] : tensor // CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor to tensor -// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor +// CHECK: %[[VAL_17:.*]] = tensor.extract %[[VAL_16]][] : tensor // CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor to tensor -// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor +// CHECK: %[[VAL_19:.*]] = tensor.extract %[[VAL_18]][] : tensor // CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]]) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 233d5a3ced3..8038f502bb7 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1011,6 +1011,7 @@ cc_library( "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir index 9806e7971c5..ad70631a3b1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir @@ -11,7 +11,7 @@ func @testIf1Result(tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> { } : (tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb1, ^bb2 // CHECK: ^bb1: // CHECK: [[THEN:%.+]] = call @testIf1Then(%arg1, %arg2) @@ -36,7 +36,7 @@ func @testIf3Result(tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, } : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb1, ^bb2 // CHECK: ^bb1: // CHECK: [[THEN:%.+]]:3 = call @testIf3Then(%arg1) @@ -65,7 +65,7 @@ func @testIfCasts(%arg0: tensor, %arg1: tensor>>) -> } : (tensor, tensor>>) -> tensor>> return %0: tensor>> // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb1, ^bb2 // CHECK: ^bb1: // CHECK: [[CAST0:%.+]] = "tf.Cast"(%arg1) {Truncate = false} : (tensor>>) -> tensor @@ -93,7 +93,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> { ^bb0(%arg0: tensor<4xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>): // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"(%arg0) : (tensor<4xi1>) -> tensor - // CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor + // CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor %1 = "tf.If"(%arg0, %arg1, %arg2) { then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false } : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> @@ -118,7 +118,7 @@ func @testWhile2Result(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<* // CHECK: ^bb1([[CONDARG0:%.+]]: tensor<*xf32>, [[CONDARG1:%.+]]: tensor<*xf32>): // CHECK: [[CONTINUE:%.+]] = call @testWhile2Cond(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb2([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>), ^bb3([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>) // CHECK: ^bb2([[BODYARG0:%.+]]: tensor<*xf32>, [[BODYARG1:%.+]]: tensor<*xf32>): // CHECK: [[BODYRETS:%.+]]:2 = call @testWhile2Body([[BODYARG0]], [[BODYARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) @@ -142,7 +142,7 @@ func @testWhile0Result() { // CHECK: ^bb1: // CHECK: [[CONTINUE:%.+]] = call @testWhile0Cond() : () -> tensor // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb2, ^bb3 // CHECK: ^bb2: // CHECK: call @testWhile0Body() : () -> () @@ -166,7 +166,7 @@ func @testComplexWhile1Result(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> (te // CHECK: ^bb1([[CONDARG0:%.+]]: tensor<*xf32>, [[CONDARG1:%.+]]: tensor<*xf32>): // CHECK: [[CONTINUE:%.+]] = call @testWhile2Cond([[CONDARG0]], [[CONDARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> tensor // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: cond_br [[PRED]], ^bb2([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>), ^bb3([[CONDARG0]], [[CONDARG1]] : tensor<*xf32>, tensor<*xf32>) // CHECK: ^bb2([[BODYARG0:%.+]]: tensor<*xf32>, [[BODYARG1:%.+]]: tensor<*xf32>): // CHECK: [[BODYRETS:%.+]]:2 = call @testWhile2Body([[BODYARG0]], [[BODYARG1]]) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) @@ -206,7 +206,7 @@ func @testWhileCasts(%arg0: tensor>>) -> (tensor): // 2 preds: ^bb0, ^bb2 // CHECK: [[CONTINUE:%.+]] = call @testWhileCond([[CONDARG0]]) : (tensor) -> tensor // CHECK: [[TOBOOL:%.+]] = "tf.ToBool"([[CONTINUE]]) : (tensor) -> tensor -// CHECK: [[PRED:%.+]] = extract_element [[TOBOOL]][] : tensor +// CHECK: [[PRED:%.+]] = tensor.extract [[TOBOOL]][] : tensor // CHECK: [[CASTCONDARG0:%.+]] = "tf.Cast"([[CONDARG0]]) {Truncate = false} : (tensor) -> tensor>> // CHECK: cond_br [[PRED]], ^bb2([[CASTCONDARG0]] : tensor>>), ^bb3([[CASTCONDARG0]] : tensor>>) // CHECK: ^bb2([[BODYARG0:%.+]]: tensor>>): // pred: ^bb1 diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index a5d76619416..6adce66a094 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -17,6 +17,7 @@ limitations under the License. // TensorFlow dialect to MLIR Control Flow Graph (CFG) form. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -42,7 +43,7 @@ struct FunctionalControlFlowToCFG // control flow op into an i1 value. static Value LowerCondition(Location loc, Value value, OpBuilder* builder) { auto zero_d = builder->create(loc, value); - auto scalar = builder->create(loc, zero_d); + auto scalar = builder->create(loc, zero_d); return scalar.getResult(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir index e7211aa206b..a5286fd5f6e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir @@ -2,14 +2,14 @@ // RUN: kernel-gen-opt %s --func-bufferize --final-bufferize --promote-buffers-to-stack | FileCheck %s --check-prefixes=CHECK,ALLOCA -// CHECK-LABEL: @extract_element +// CHECK-LABEL: @tensor.extract // CHECK-SAME: (%[[ARG:.*]]: memref) -> f32 -func @extract_element(%arg : tensor) -> f32 { +func @tensor.extract(%arg : tensor) -> f32 { // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[RESULT:.*]] = load %[[ARG]][%[[C0]]] // CHECK: return %[[RESULT]] %c0 = constant 0 : index - %result = extract_element %arg[%c0] : tensor + %result = tensor.extract %arg[%c0] : tensor return %result : f32 } @@ -30,7 +30,7 @@ func @tensor_from_elements(%a : f32) -> f32 { %c = constant 2.3 : f32 %tfe = tensor_from_elements %a, %b, %c : tensor<3xf32> %c0 = constant 0 : index - %result = extract_element %tfe[%c0] : tensor<3xf32> + %result = tensor.extract %tfe[%c0] : tensor<3xf32> return %result : f32 } @@ -54,7 +54,7 @@ func @dynamic_tensor_from_elements(%arg : tensor<*xf32>) -> index { yield %elem : index } : tensor %c0 = constant 0 : index - %result = extract_element %tfe[%c0] : tensor + %result = tensor.extract %tfe[%c0] : tensor return %result : index } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index a1648745e44..5ed47ac2416 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -87,6 +87,8 @@ cc_library( copts = if_cuda_is_configured(["-DGOOGLE_CUDA=1"]) + if_rocm_is_configured(["-DTENSORFLOW_USE_ROCM=1"]), deps = [ "@llvm-project//mlir:Affine", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TensorTransforms", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index a43d910e96e..ba2e78ba255 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -30,6 +30,8 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/Dialect/StandardOps/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tensor/Transforms/Passes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -109,7 +111,10 @@ struct HloBufferizePass : public HloBufferizePassBase { OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); + target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalDialect(); CustomBufferizeTypeConverter converter; @@ -149,7 +154,8 @@ struct FinalBufferizePass : public FinalBufferizePassBase { // TODO(b/173201243): Move to tablegen. void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + tensor::TensorDialect, tf_framework::TFFrameworkDialect, + lmhlo::LmhloDialect>(); } public: @@ -157,13 +163,14 @@ struct FinalBufferizePass : public FinalBufferizePassBase { auto& context = getContext(); ConversionTarget target(context); target.addLegalDialect(); target.addLegalOp(); target.addIllegalDialect(); - target.addIllegalOp(); BufferizeTypeConverter converter; @@ -175,6 +182,7 @@ struct FinalBufferizePass : public FinalBufferizePassBase { typesAreLegal); OwningRewritePatternList patterns; + populateTensorBufferizePatterns(&context, converter, patterns); populateStdBufferizePatterns(&context, converter, patterns); populateEliminateBufferizeMaterializationsPatterns(&context, converter, patterns); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc index d4a9baf17b9..7743b03bb1e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project @@ -49,6 +50,7 @@ struct ShapeToDescriptorsPass target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); // Don't mark the primary Cstr/Assuming ops as illegal, so they can be // lowered at a later time to assertions. target.addLegalOp