From 9b2dc5ff87f6e9c73be5edb2f5dc85575a3dfc2e Mon Sep 17 00:00:00 2001
From: Robert Suderman <suderman@google.com>
Date: Mon, 4 May 2020 11:43:57 -0700
Subject: [PATCH] Constant evaluation of XLA HLO Concatenate

PiperOrigin-RevId: 309788484
Change-Id: Ie001860de87485ab31f3e5183c9e2a0ce899014d
---
 tensorflow/compiler/mlir/xla/ir/hlo_ops.cc    | 49 +++++++++++++++++++
 .../compiler/mlir/xla/tests/canonicalize.mlir | 48 ++++++++++++++++++
 2 files changed, 97 insertions(+)

diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index c9742ad5337..e597ca9da68 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -842,6 +842,51 @@ void ConcatenateOp::getCanonicalizationPatterns(
   results.insert<ConcatenateOperandRemoval>(context);
 }
 
+template <typename T>
+static Attribute foldConcatenateHelper(ConcatenateOp* op,
+                                       ArrayRef<Attribute> operands) {
+  auto axis = op->dimension().getLimitedValue();
+  auto type = op->getType().cast<ShapedType>();
+
+  SmallVector<T, 6> values;
+  auto shape = type.getShape();
+
+  size_t top_size = 1;
+  for (int i = 0; i < axis; i++) {
+    top_size = top_size * shape[i];
+  }
+
+  for (size_t i = 0; i < top_size; i++) {
+    for (auto operand : operands) {
+      DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
+      size_t bottom_size = attr.getNumElements() / top_size;
+      auto iter = attr.getValues<T>().begin() + i * bottom_size;
+      values.append(iter, iter + bottom_size);
+    }
+  }
+
+  return DenseElementsAttr::get(type, values);
+}
+
+static Attribute foldConcatenate(ConcatenateOp* op,
+                                 ArrayRef<Attribute> operands) {
+  for (auto operand : operands) {
+    if (!operand) return {};
+  }
+
+  auto type = op->getResult().getType().cast<ShapedType>();
+  auto etype = type.getElementType();
+  if (etype.isa<IntegerType>()) {
+    return foldConcatenateHelper<APInt>(op, operands);
+  }
+
+  if (etype.isa<FloatType>()) {
+    return foldConcatenateHelper<APFloat>(op, operands);
+  }
+
+  return {};
+}
+
 OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
   if (getNumOperands() == 1) return getOperand(0);
 
@@ -849,6 +894,10 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
   if (!type.hasStaticShape()) return {};
 
   auto axis = dimension().getLimitedValue();
+  if (auto attr = foldConcatenate(this, operands)) {
+    return attr;
+  }
+
   llvm::SmallVector<Value, 6> new_operands;
   for (auto operand : getOperands()) {
     auto ty = operand.getType().cast<ShapedType>();
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 5f28693c49d..3e43a8f4b3a 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -43,6 +43,54 @@ func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> ten
   return %0 : tensor<0xf32>
 }
 
+// CHECK-LABEL: concatenate_const_1D
+func @concatenate_const_1D() -> tensor<4xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
+  %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
+  %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_const_1D_float
+func @concatenate_const_1D_float() -> tensor<4xf32> {
+  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+
+  %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
+  %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<4xf32>
+}
+
+// CHECK-LABEL: concatenate_const_2D_vertical
+func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
+  // CHECK-SAME: [0, 1], [2, 3]
+  // CHECK-SAME: ]>
+  %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
+  %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: concatenate_const_2D_horizontal
+func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
+  // CHECK-SAME: [0, 2], [1, 3]
+  // CHECK-SAME: ]>
+  %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
+  %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<2x2xi32>
+}
 
 // CHECK-LABEL: dynamic_slice_variable_start
 func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {