Constant evaluation of XLA HLO Concatenate

PiperOrigin-RevId: 309788484
Change-Id: Ie001860de87485ab31f3e5183c9e2a0ce899014d
Robert Suderman 2020-05-04 11:43:57 -07:00 committed by TensorFlower Gardener
parent 7f631f41f2
commit 9b2dc5ff87
2 changed files with 97 additions and 0 deletions
tensorflow/compiler/mlir/xla

@@ -842,6 +842,51 @@ void ConcatenateOp::getCanonicalizationPatterns(
  results.insert&lt;ConcatenateOperandRemoval&gt;(context);
}

template &lt;typename T&gt;
static Attribute foldConcatenateHelper(ConcatenateOp* op,
                                       ArrayRef&lt;Attribute&gt; operands) {
  auto axis = op-&gt;dimension().getLimitedValue();
  auto type = op-&gt;getType().cast&lt;ShapedType&gt;();

  SmallVector&lt;T, 6&gt; values;
  auto shape = type.getShape();

  // top_size is the number of outer slices in front of the concatenation
  // axis; the result is assembled one outer slice at a time.
  size_t top_size = 1;
  for (int i = 0; i &lt; axis; i++) {
    top_size = top_size * shape[i];
  }

  for (size_t i = 0; i &lt; top_size; i++) {
    for (auto operand : operands) {
      DenseElementsAttr attr = operand.cast&lt;DenseElementsAttr&gt;();
      // Elements this operand contributes to each outer slice.
      size_t bottom_size = attr.getNumElements() / top_size;
      auto iter = attr.getValues&lt;T&gt;().begin() + i * bottom_size;
      values.append(iter, iter + bottom_size);
    }
  }

  return DenseElementsAttr::get(type, values);
}

static Attribute foldConcatenate(ConcatenateOp* op,
                                 ArrayRef&lt;Attribute&gt; operands) {
  // Folding requires every operand to already be a constant attribute.
  for (auto operand : operands) {
    if (!operand) return {};
  }

  // Dispatch on the element type; only integer and float constants fold.
  auto type = op-&gt;getResult().getType().cast&lt;ShapedType&gt;();
  auto etype = type.getElementType();
  if (etype.isa&lt;IntegerType&gt;()) {
    return foldConcatenateHelper&lt;APInt&gt;(op, operands);
  }

  if (etype.isa&lt;FloatType&gt;()) {
    return foldConcatenateHelper&lt;APFloat&gt;(op, operands);
  }

  return {};
}
OpFoldResult ConcatenateOp::fold(ArrayRef&lt;Attribute&gt; operands) {
  if (getNumOperands() == 1) return getOperand(0);
@@ -849,6 +894,10 @@ OpFoldResult ConcatenateOp::fold(ArrayRef&lt;Attribute&gt; operands) {
  if (!type.hasStaticShape()) return {};

  auto axis = dimension().getLimitedValue();
  // Try to fold the whole concatenation to a single constant first.
  if (auto attr = foldConcatenate(this, operands)) {
    return attr;
  }

  llvm::SmallVector&lt;Value, 6&gt; new_operands;
  for (auto operand : getOperands()) {
    auto ty = operand.getType().cast&lt;ShapedType&gt;();
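To make the traversal in foldConcatenateHelper concrete, here is a minimal standalone sketch of the same row-major walk, with plain std::vector standing in for DenseElementsAttr and a hypothetical ConcatRowMajor helper (illustration only, not part of the commit):

#include &lt;cstddef&gt;
#include &lt;vector&gt;

// Concatenate row-major tensors along an axis by walking top_size outer
// slices (the product of the dimensions in front of the axis) and appending
// each operand's slice in turn.
std::vector&lt;int&gt; ConcatRowMajor(const std::vector&lt;std::vector&lt;int&gt;&gt;&amp; operands,
                                size_t top_size) {
  std::vector&lt;int&gt; result;
  for (size_t i = 0; i &lt; top_size; ++i) {
    for (const std::vector&lt;int&gt;&amp; operand : operands) {
      // Elements this operand contributes per outer slice.
      size_t bottom_size = operand.size() / top_size;
      result.insert(result.end(), operand.begin() + i * bottom_size,
                    operand.begin() + (i + 1) * bottom_size);
    }
  }
  return result;
}

// Example: the 2x1 tensors [[0], [1]] and [[2], [3]] concatenated along
// axis 1 give top_size = 2 and bottom_size = 1, so
// ConcatRowMajor({{0, 1}, {2, 3}}, 2) returns {0, 2, 1, 3} -- the
// interleaving checked by concatenate_const_2D_horizontal below.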

@@ -43,6 +43,54 @@ func @concatenate_empty_float(%arg0: tensor&lt;0xf32&gt;, %arg1: tensor&lt;0xf32&gt;) -&gt; tensor&lt;0xf32&gt;
  return %0 : tensor&lt;0xf32&gt;
}
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -&gt; tensor&lt;4xi32&gt; {
  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense&lt;[0, 1, 2, 3]&gt;
  %0 = xla_hlo.constant dense&lt;[0, 1]&gt; : tensor&lt;2xi32&gt;
  %1 = xla_hlo.constant dense&lt;[2, 3]&gt; : tensor&lt;2xi32&gt;
  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor&lt;2xi32&gt;, tensor&lt;2xi32&gt;) -&gt; tensor&lt;4xi32&gt;

  // CHECK: return [[VAL]]
  return %2 : tensor&lt;4xi32&gt;
}
// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -&gt; tensor&lt;4xf32&gt; {
  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense&lt;[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]&gt;
  %0 = xla_hlo.constant dense&lt;[0.0, 1.0]&gt; : tensor&lt;2xf32&gt;
  %1 = xla_hlo.constant dense&lt;[2.0, 3.0]&gt; : tensor&lt;2xf32&gt;
  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor&lt;2xf32&gt;, tensor&lt;2xf32&gt;) -&gt; tensor&lt;4xf32&gt;

  // CHECK: return [[VAL]]
  return %2 : tensor&lt;4xf32&gt;
}
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -&gt; tensor&lt;2x2xi32&gt; {
  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense&lt;[
  // CHECK-SAME: [0, 1], [2, 3]
  // CHECK-SAME: ]&gt;
  %0 = xla_hlo.constant dense&lt;[[0, 1]]&gt; : tensor&lt;1x2xi32&gt;
  %1 = xla_hlo.constant dense&lt;[[2, 3]]&gt; : tensor&lt;1x2xi32&gt;
  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor&lt;1x2xi32&gt;, tensor&lt;1x2xi32&gt;) -&gt; tensor&lt;2x2xi32&gt;

  // CHECK: return [[VAL]]
  return %2 : tensor&lt;2x2xi32&gt;
}
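With dimension = 0 there are no dimensions in front of the axis, so top_size = 1 and the fold simply lays the operands' elements end to end, just as in the two 1-D tests above.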
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -&gt; tensor&lt;2x2xi32&gt; {
  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense&lt;[
  // CHECK-SAME: [0, 2], [1, 3]
  // CHECK-SAME: ]&gt;
  %0 = xla_hlo.constant dense&lt;[[0], [1]]&gt; : tensor&lt;2x1xi32&gt;
  %1 = xla_hlo.constant dense&lt;[[2], [3]]&gt; : tensor&lt;2x1xi32&gt;
  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor&lt;2x1xi32&gt;, tensor&lt;2x1xi32&gt;) -&gt; tensor&lt;2x2xi32&gt;

  // CHECK: return [[VAL]]
  return %2 : tensor&lt;2x2xi32&gt;
}
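Concatenating along dimension = 1 exercises the interleaving path: top_size = shape[0] = 2 and each 2x1 operand contributes bottom_size = 1 element per outer slice, so the folded constant is [[0, 2], [1, 3]] rather than a plain append.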
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {