Add support for lowering element-wise HLO operations on dynamic shapes to LHLO.

This uses hand-written shape derivation functions and is not complete. The goal
is to start experimenting with these and to land some initial code.

PiperOrigin-RevId: 296403560
Change-Id: I996cc5f862604ca76344076cb6699998757d4164
commit 45c98a790e
parent 9b53275852
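
In short: for an element-wise op whose result type has dynamic dimensions, the
lowering now derives the result shape from the operands at runtime and
allocates the output buffer from that shape. A condensed sketch of the
generated IR for a rank-2 operand (value names are illustrative; the
authoritative expectations are the @add_dyn and @tanh_dyn tests below):

    %dim0 = dim %arg0, 0 : memref<?x?xf32>
    %ic0 = index_cast %dim0 : index to i64
    %dim1 = dim %arg0, 1 : memref<?x?xf32>
    %ic1 = index_cast %dim1 : index to i64
    %shape = "xla_hlo.scalars_to_dimension_tensor"(%ic0, %ic1)
        : (i64, i64) -> tensor<2xi64>
    // ... a buffer is then allocated from %shape, and the LHLO op writes
    // into it:
    "xla_lhlo.add"(%arg0, %arg1, %result)
        : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()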
@@ -142,6 +142,19 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "hlo_shape_derivation",
+    srcs = [],
+    hdrs = ["transforms/hlo_shape_derivation.h"],
+    deps = [
+        ":hlo",
+        ":lhlo",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
 cc_library(
     name = "lhlo_legalize_to_affine",
     srcs = ["transforms/lhlo_legalize_to_affine.cc"],
@@ -218,9 +231,9 @@ cc_library(
     srcs = ["transforms/hlo_legalize_to_lhlo.cc"],
     deps = [
         ":hlo",
+        ":hlo_shape_derivation",
         ":lhlo",
         "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
@@ -1,4 +1,4 @@
-// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal %s -o - | FileCheck %s --dump-input=always
+// RUN: tf-opt -hlo-legalize-to-lhlo -lhlo-redundant-copies-removal -split-input-file %s -o - | FileCheck %s -dump-input-on-failure
 
 // CHECK-LABEL: func @attrs
 func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -11,6 +11,8 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @func_op
 func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@@ -20,6 +22,8 @@ func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @func_op_long
 func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@@ -45,6 +49,8 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
   // CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remove_lhlo_copy_op_created_from_tensor_store
 func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: memref<f32>) {
   %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
@@ -58,6 +64,8 @@ func @remove_lhlo_copy_op_created_from_tensor_store(%arg0: tensor<f32>, %arg1: t
   // CHECK-NOT: dealloc %[[ALLOC_OPERAND]] : memref<f32>
   // CHECK: "xla_lhlo.terminator"() : () -> ()
 
+// -----
+
 // CHECK-LABEL: func @fusion
 func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
              %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -77,6 +85,8 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
   "xla_lhlo.terminator"() : () -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @copy
 func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -87,6 +97,8 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @exp
 func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -97,6 +109,8 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
              %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
@@ -110,6 +124,8 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @compare
 func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
@@ -122,6 +138,8 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @broadcast
 func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
   %tensor_operand = tensor_load %operand : memref<5xf32>
@@ -133,6 +151,8 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dyn_broadcast
 func @dyn_broadcast(%operand: memref<?x?xf32>) {
   %tensor_operand = tensor_load %operand : memref<?x?xf32>
@@ -157,6 +177,8 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @iota
 func @iota(%result: memref<10xi32>) {
   %tensor_result = "xla_hlo.iota"()
@@ -166,6 +188,8 @@ func @iota(%result: memref<10xi32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @abs
 func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -176,6 +200,8 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @ceil
 func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -186,6 +212,8 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @convert
 func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -196,6 +224,8 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @cos
 func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -206,6 +236,8 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @neg
 func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -216,6 +248,8 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @sign
 func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -226,6 +260,8 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @tanh
 func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_operand = tensor_load %operand : memref<2x2xf32>
@@ -236,6 +272,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remainder
 func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
   %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
@@ -246,3 +284,49 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
   tensor_store %tensor_result, %result : memref<2x2xf32>
   return
 }
+
+// -----
+
+// Dynamic shape binary element-wise operation.
+// CHECK-LABEL: func @add_dyn
+func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
+  %result = "xla_hlo.add"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// Dynamic shape unary element-wise operation.
+// CHECK-LABEL: func @tanh_dyn
+func @tanh_dyn(%arg0: tensor<?x?xf32>) {
+  %result = "xla_hlo.tanh"(%arg0)
+      : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
+  // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
+  // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // CHECK: %[[SHAPE:.*]] = "xla_hlo.scalars_to_dimension_tensor"(%[[IC0]], %[[IC1]]) : (i64, i64) -> tensor<2xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
+  // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
+  // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
@@ -30,6 +30,7 @@ limitations under the License.
 #include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
 
@@ -127,9 +128,24 @@ class HloToLhloOpConverter : public ConversionPattern {
       ConversionPatternRewriter& rewriter) const final {
     const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : original_results) {
-      buffer_args.push_back(
-          InsertAllocAndDealloc(op->getLoc(), result, &rewriter));
+    for (auto result : llvm::enumerate(original_results)) {
+      RankedTensorType resultType =
+          result.value().getType().dyn_cast<RankedTensorType>();
+      if (!resultType) {
+        return matchFailure();
+      }
+      if (resultType.hasStaticShape()) {
+        buffer_args.push_back(
+            InsertAllocAndDealloc(op->getLoc(), result.value(), &rewriter));
+      } else {
+        Value shape_value = ShapeDerivation<HloOpTy>::impl::deriveShapeFromOp(
+            op, result.index(), &rewriter);
+        if (!shape_value) {
+          return matchFailure();
+        }
+        buffer_args.push_back(InsertDynamicAllocAndDealloc(
+            op->getLoc(), result.value(), shape_value, &rewriter));
+      }
     }
     rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
                               op->getAttrs());
@@ -320,6 +336,7 @@ struct HloLegalizeToLhlo : public ModulePass<HloLegalizeToLhlo> {
     target.addIllegalOp<mlir::TensorLoadOp>();
    target.addIllegalOp<mlir::TensorStoreOp>();
     target.addLegalOp<ModuleTerminatorOp>();
+    target.addLegalOp<ScalarsToDimensionTensorOp>();
     target.addIllegalDialect<xla_hlo::XlaHloDialect>();
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       auto inputs = op.getType().getInputs();
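
The else-branch above is the new dynamic-shape path: deriveShapeFromOp emits
IR that materializes the result shape, and InsertDynamicAllocAndDealloc
consumes that shape tensor to size the result buffer. Per the @add_dyn
expectations in the test above, the allocation side comes out roughly as
follows (a sketch; value names are illustrative):

    %c0 = constant 0 : index
    %e0 = extract_element %shape[%c0] : tensor<2xi64>
    %s0 = index_cast %e0 : i64 to index
    %c1 = constant 1 : index
    %e1 = extract_element %shape[%c1] : tensor<2xi64>
    %s1 = index_cast %e1 : i64 to index
    %buffer = alloc(%s0, %s1) : memref<?x?xf32>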
tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h (new file, 130 lines)
@@ -0,0 +1,130 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
+#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
+
+#include "mlir/Dialect/StandardOps/Ops.h"  // TF:llvm-project
+#include "mlir/IR/Attributes.h"  // TF:llvm-project
+#include "mlir/IR/Builders.h"  // TF:llvm-project
+#include "mlir/IR/Location.h"  // TF:llvm-project
+#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Operation.h"  // TF:llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // TF:llvm-project
+#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
+
+namespace mlir {
+namespace xla_hlo {
+
+// This file contains implementations of shape derivation functions that,
+// given an operation and a result number, produce IR that computes the
+// shape of the given result at runtime from the operands of the provided
+// operation.
+// These should eventually be generated from annotations on the HLO ops
+// using the new shape dialect. While that is still in the works, we
+// hardcode the expected IR here to unblock progress.
+// The implementation is based on templates so that these derivation
+// functions can be used from templated code.
+
+namespace impl {
+
+struct UnknownShape {
+  // Default shape derivation function that simply fails with a runtime error.
+  static Value deriveShapeFromOp(Operation* op, int result_position,
+                                 ConversionPatternRewriter* rewriter) {
+    op->emitOpError()
+        << "dynamic result shapes cannot be derived for this operation";
+    return {};
+  }
+};
+
+struct SameShapeAsFirstOperand {
+  // Shape derivation function that computes the shape of the result based on
+  // the first argument. For a 2-dimensional input tensor, this produces IR of
+  // the form
+  //
+  //   %0 = dim %arg0, 0 : memref<?x?xf32>
+  //   %1 = index_cast %0 : index to i64
+  //   %2 = dim %arg0, 1 : memref<?x?xf32>
+  //   %3 = index_cast %2 : index to i64
+  //   %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
+  //            : (i64, i64) -> tensor<2xi64>
+  //
+  // and returns %4 as the shape value.
+  static Value deriveShapeFromOp(Operation* op, int result_position,
+                                 ConversionPatternRewriter* rewriter) {
+    Value operand = op->getOperand(0);
+    ShapedType operand_type = operand.getType().dyn_cast<ShapedType>();
+    if (!operand_type) {
+      op->emitOpError() << "first operand has no shaped type";
+      return {};
+    }
+    auto loc = op->getLoc();
+    SmallVector<Value, 4> shape_values;
+    shape_values.reserve(operand_type.getRank());
+    auto shape_scalar_type = rewriter->getIntegerType(64);
+    for (auto element : llvm::enumerate(operand_type.getShape())) {
+      if (element.value() == ShapedType::kDynamicSize) {
+        // Dynamic dimensions are queried at runtime and cast to i64.
+        Value dim = rewriter->create<DimOp>(loc, operand, element.index());
+        shape_values.push_back(
+            rewriter->create<IndexCastOp>(loc, dim, shape_scalar_type));
+      } else {
+        // Static dimensions become i64 constants.
+        shape_values.push_back(rewriter->create<ConstantOp>(
+            loc, rewriter->getI64IntegerAttr(element.value())));
+      }
+    }
+    return rewriter->create<ScalarsToDimensionTensorOp>(
+        loc,
+        RankedTensorType::get({operand_type.getRank()}, shape_scalar_type),
+        shape_values);
+  }
+};
+
+}  // namespace impl
+
+// Default template to cover HLO operations whose shape derivation is unknown.
+template <typename HloOpTy>
+struct ShapeDerivation {
+  using impl = impl::UnknownShape;
+};
+
+// Element-wise operations that have the shape of their first operand.
+
+#define SAME_SHAPE_AS_FIRST_OPERAND(Op)         \
+  template <>                                   \
+  struct ShapeDerivation<Op> {                  \
+    using impl = impl::SameShapeAsFirstOperand; \
+  };
+
+SAME_SHAPE_AS_FIRST_OPERAND(AbsOp)
+SAME_SHAPE_AS_FIRST_OPERAND(AddOp)
+SAME_SHAPE_AS_FIRST_OPERAND(AndOp)
+SAME_SHAPE_AS_FIRST_OPERAND(CeilOp)
+SAME_SHAPE_AS_FIRST_OPERAND(CosOp)
+SAME_SHAPE_AS_FIRST_OPERAND(DivOp)
+SAME_SHAPE_AS_FIRST_OPERAND(ExpOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MaxOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MinOp)
+SAME_SHAPE_AS_FIRST_OPERAND(MulOp)
+SAME_SHAPE_AS_FIRST_OPERAND(NegOp)
+SAME_SHAPE_AS_FIRST_OPERAND(RemOp)
+SAME_SHAPE_AS_FIRST_OPERAND(SubOp)
+SAME_SHAPE_AS_FIRST_OPERAND(TanhOp)
+
+#undef SAME_SHAPE_AS_FIRST_OPERAND
+
+}  // namespace xla_hlo
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_HLO_SHAPE_DERIVATION_H_
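
For reference, the template dispatch defined here is what the templated
HloToLhloOpConverter consumes; a minimal usage sketch (as in
hlo_legalize_to_lhlo.cc in this change):

    // Selects impl::SameShapeAsFirstOperand for ops with a specialization
    // and falls back to impl::UnknownShape, whose derivation emits an error
    // and returns a null Value.
    Value shape_value = ShapeDerivation<HloOpTy>::impl::deriveShapeFromOp(
        op, result.index(), &rewriter);
    if (!shape_value) return matchFailure();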