Add XLA HLO -> LMHLO conversion for SelectAndScatter
- Add conversion and unit test PiperOrigin-RevId: 338687838 Change-Id: I26eba3bb60236151147f348ecbfb09fafd33a475
This commit is contained in:
parent
080f49a6a5
commit
6bbe9adb45
@ -2,7 +2,7 @@
|
||||
|
||||
HloModule TestModule
|
||||
|
||||
// CHECK: func @TestComputation
|
||||
// CHECK-LABEL: func @TestComputation
|
||||
|
||||
FusedComputation {
|
||||
// CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>}
|
||||
@ -24,7 +24,7 @@ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
|
||||
ROOT rhs = s32[] parameter(1)
|
||||
}
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK: "lmhlo.scatter"
|
||||
// CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>):
|
||||
// CHECK: "mhlo.return"(%[[ARG6]])
|
||||
@ -46,3 +46,38 @@ ENTRY main {
|
||||
scatter_dims_to_operand_dims={0},
|
||||
index_vector_dim=1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule SelectAndScatter
|
||||
|
||||
%ge_F32 (lhs.5: f32[], rhs.6: f32[]) -> pred[] {
|
||||
%lhs.5 = f32[] parameter(0)
|
||||
%rhs.6 = f32[] parameter(1)
|
||||
ROOT %compare.7 = pred[] compare(f32[] %lhs.5, f32[] %rhs.6), direction=GE
|
||||
}
|
||||
|
||||
%add_F32 (lhs.9: f32[], rhs.10: f32[]) -> f32[] {
|
||||
%lhs.9 = f32[] parameter(0)
|
||||
%rhs.10 = f32[] parameter(1)
|
||||
ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK: "lmhlo.select_and_scatter"
|
||||
// CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
|
||||
// CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"}
|
||||
// CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> ()
|
||||
// CHECK: ^bb0(%[[ARG2:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>):
|
||||
// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]]
|
||||
// CHECK: "mhlo.return"(%[[ADD]]) : (tensor<f32>) -> ()
|
||||
// CHECK: padding = dense<0> : tensor<1xi64>
|
||||
// CHECK: window_dimensions = dense<3> : tensor<1xi64>
|
||||
// CHECK: window_strides = dense<3> : tensor<1xi64>
|
||||
// CHECK: (memref<6xf32>, memref<2xf32>, memref<f32>, memref<6xf32>) -> ()
|
||||
ENTRY main () -> f32[6] {
|
||||
%operand = f32[6]{0} parameter(0)
|
||||
%source = f32[2]{0} parameter(1)
|
||||
%init = f32[] parameter(2)
|
||||
ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32
|
||||
}
|
||||
|
@ -407,9 +407,9 @@ LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) {
|
||||
const ::xla::ScatterDimensionNumbers& xla_scatter_dim =
|
||||
scatter_instr->scatter_dimension_numbers();
|
||||
auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get(
|
||||
getI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
|
||||
getI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
|
||||
getI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
|
||||
GetI64DenseElementsAttr(xla_scatter_dim.update_window_dims()),
|
||||
GetI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()),
|
||||
GetI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()),
|
||||
builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()),
|
||||
module_.getContext());
|
||||
return scatter_dimension_numbers;
|
||||
@ -443,6 +443,43 @@ Status LhloDialectEmitter::HandleScatter(HloInstruction* instr) {
|
||||
return EmitScatterOp(instr).status();
|
||||
}
|
||||
|
||||
StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
|
||||
HloInstruction* instr) {
|
||||
TF_ASSIGN_OR_RETURN(auto select_and_scatter,
|
||||
CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
|
||||
|
||||
// copy attributes
|
||||
auto* select_and_scatter_instr =
|
||||
::xla::Cast<::xla::HloSelectAndScatterInstruction>(instr);
|
||||
const ::xla::Window& window = select_and_scatter_instr->window();
|
||||
|
||||
select_and_scatter.window_dimensionsAttr(
|
||||
GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
|
||||
return static_cast<int64_t>(dim.size());
|
||||
}));
|
||||
select_and_scatter.window_stridesAttr(
|
||||
GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
|
||||
return static_cast<int64_t>(dim.stride());
|
||||
}));
|
||||
select_and_scatter.paddingAttr(
|
||||
GetWindowElements(window, [](const ::xla::WindowDimension& dim) {
|
||||
return static_cast<int64_t>(dim.padding_low());
|
||||
}));
|
||||
|
||||
// import select and scatter computation as region
|
||||
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
|
||||
*select_and_scatter_instr->select(), &select_and_scatter.select(),
|
||||
&builder_));
|
||||
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
|
||||
*select_and_scatter_instr->scatter(), &select_and_scatter.scatter(),
|
||||
&builder_));
|
||||
return select_and_scatter;
|
||||
}
|
||||
|
||||
Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* instr) {
|
||||
return EmitSelectAndScatterOp(instr).status();
|
||||
}
|
||||
|
||||
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
|
||||
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
|
||||
const ::xla::ShapeIndex& shape_index) {
|
||||
|
@ -49,17 +49,30 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
|
||||
::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
|
||||
::xla::HloInstruction* instr);
|
||||
|
||||
private:
|
||||
template <typename OpType>
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr);
|
||||
|
||||
template <typename T>
|
||||
DenseIntElementsAttr getI64DenseElementsAttr(const T& container) {
|
||||
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
|
||||
return builder_.getI64TensorAttr(
|
||||
{container.data(), static_cast<size_t>(container.size())});
|
||||
}
|
||||
|
||||
DenseIntElementsAttr GetWindowElements(
|
||||
const ::xla::Window& window,
|
||||
std::function<int64_t(const xla::WindowDimension& dim)> getter) {
|
||||
llvm::SmallVector<int64_t, 4> elements;
|
||||
elements.reserve(window.dimensions_size());
|
||||
for (const ::xla::WindowDimension& dim : window.dimensions()) {
|
||||
elements.push_back(getter(dim));
|
||||
}
|
||||
return GetI64DenseElementsAttr(elements);
|
||||
}
|
||||
|
||||
tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final;
|
||||
|
||||
// Computation parameters don't need any specific handling when they are
|
||||
@ -71,6 +84,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
tensorflow::Status HandleSort(::xla::HloInstruction* instr) final;
|
||||
tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final;
|
||||
tensorflow::Status HandleScatter(::xla::HloInstruction* instr) final;
|
||||
tensorflow::Status HandleSelectAndScatter(::xla::HloInstruction* instr) final;
|
||||
|
||||
// Helper function that recursively visits the tuple structure in
|
||||
// `current_shape`, and reconstruct a matching lmhlo::TupleOp.
|
||||
|
Loading…
x
Reference in New Issue
Block a user