Constant evaluation of XLA HLO Concatenate
PiperOrigin-RevId: 309788484 Change-Id: Ie001860de87485ab31f3e5183c9e2a0ce899014d
This commit is contained in:
parent
7f631f41f2
commit
9b2dc5ff87
tensorflow/compiler/mlir/xla
@ -842,6 +842,51 @@ void ConcatenateOp::getCanonicalizationPatterns(
|
||||
results.insert<ConcatenateOperandRemoval>(context);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static Attribute foldConcatenateHelper(ConcatenateOp* op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto axis = op->dimension().getLimitedValue();
|
||||
auto type = op->getType().cast<ShapedType>();
|
||||
|
||||
SmallVector<T, 6> values;
|
||||
auto shape = type.getShape();
|
||||
|
||||
size_t top_size = 1;
|
||||
for (int i = 0; i < axis; i++) {
|
||||
top_size = top_size * shape[i];
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < top_size; i++) {
|
||||
for (auto operand : operands) {
|
||||
DenseElementsAttr attr = operand.cast<DenseElementsAttr>();
|
||||
size_t bottom_size = attr.getNumElements() / top_size;
|
||||
auto iter = attr.getValues<T>().begin() + i * bottom_size;
|
||||
values.append(iter, iter + bottom_size);
|
||||
}
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(type, values);
|
||||
}
|
||||
|
||||
static Attribute foldConcatenate(ConcatenateOp* op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
for (auto operand : operands) {
|
||||
if (!operand) return {};
|
||||
}
|
||||
|
||||
auto type = op->getResult().getType().cast<ShapedType>();
|
||||
auto etype = type.getElementType();
|
||||
if (etype.isa<IntegerType>()) {
|
||||
return foldConcatenateHelper<APInt>(op, operands);
|
||||
}
|
||||
|
||||
if (etype.isa<FloatType>()) {
|
||||
return foldConcatenateHelper<APFloat>(op, operands);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getNumOperands() == 1) return getOperand(0);
|
||||
|
||||
@ -849,6 +894,10 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!type.hasStaticShape()) return {};
|
||||
|
||||
auto axis = dimension().getLimitedValue();
|
||||
if (auto attr = foldConcatenate(this, operands)) {
|
||||
return attr;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 6> new_operands;
|
||||
for (auto operand : getOperands()) {
|
||||
auto ty = operand.getType().cast<ShapedType>();
|
||||
|
@ -43,6 +43,54 @@ func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> ten
|
||||
return %0 : tensor<0xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_const_1D
|
||||
func @concatenate_const_1D() -> tensor<4xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
|
||||
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
|
||||
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<4xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_const_1D_float
|
||||
func @concatenate_const_1D_float() -> tensor<4xf32> {
|
||||
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
|
||||
|
||||
%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
|
||||
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_const_2D_vertical
|
||||
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
|
||||
// CHECK-SAME: [0, 1], [2, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
|
||||
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: concatenate_const_2D_horizontal
|
||||
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
|
||||
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
|
||||
// CHECK-SAME: [0, 2], [1, 3]
|
||||
// CHECK-SAME: ]>
|
||||
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
|
||||
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
|
||||
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
|
||||
// CHECK: return [[VAL]]
|
||||
return %2 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dynamic_slice_variable_start
|
||||
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
|
||||
|
Loading…
Reference in New Issue
Block a user