Add import support for HLO Scatter op.
PiperOrigin-RevId: 295792321 Change-Id: I6daf2b0b49d551a446d6e37b9e6f96fbd11fdbfa
This commit is contained in:
parent
19ac5f4f6c
commit
caad1b7a45
tensorflow/compiler/mlir/xla
@ -370,6 +370,22 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
Convert(interior_padding))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kScatter: {
|
||||
auto scatter = static_cast<HloScatterInstruction*>(instruction);
|
||||
attributes.push_back(
|
||||
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"indices_are_sorted",
|
||||
builder_->getBoolAttr(scatter->indices_are_sorted())));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
|
||||
|
||||
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
|
||||
&scatter_op.update_computation()));
|
||||
return scatter_op.getOperation();
|
||||
}
|
||||
case HloOpcode::kSetDimensionSize: {
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"dimension", builder_->getIntegerAttr(builder_->getIntegerType(32),
|
||||
@ -844,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
|
||||
return builder_->getNamedAttr("dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums) {
|
||||
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
|
||||
dnums.update_window_dims().end());
|
||||
std::vector<int64_t> inserted_window_dims(
|
||||
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
|
||||
std::vector<int64_t> scatter_dims_to_operand_dims(
|
||||
dnums.scatter_dims_to_operand_dims().begin(),
|
||||
dnums.scatter_dims_to_operand_dims().end());
|
||||
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
|
||||
Convert(update_window_dims), Convert(inserted_window_dims),
|
||||
Convert(scatter_dims_to_operand_dims),
|
||||
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
|
||||
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
|
||||
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
|
||||
source_target_pairs) {
|
||||
|
@ -121,6 +121,10 @@ class HloFunctionImporter {
|
||||
mlir::NamedAttribute ConvertGatherDimensionNumbers(
|
||||
const xla::GatherDimensionNumbers& dnums);
|
||||
|
||||
// Converts the scatter dimensions to attributes.
|
||||
mlir::NamedAttribute ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums);
|
||||
|
||||
// Converts XLA instruction source target pairs to MLIR attribute.
|
||||
mlir::NamedAttribute ConvertSourceTargetPairs(
|
||||
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
|
||||
|
@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
ROOT %Arg_0.1 = f32[] parameter(0)
|
||||
}
|
||||
|
||||
// Test scatter
|
||||
%update_computation {
|
||||
%lhs = f32[] parameter(0)
|
||||
%rhs = f32[] parameter(1)
|
||||
ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs)
|
||||
}
|
||||
|
||||
%test_scatter {
|
||||
%input_tensor = f32[200,100,300] parameter(0)
|
||||
%scatter_indices = s64[10,2] parameter(1)
|
||||
%updates = f32[10,300] parameter(2)
|
||||
ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_scatter
|
||||
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
|
||||
// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
|
||||
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
|
||||
// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
|
||||
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
|
||||
// CHECK: })
|
||||
// CHECK-SAME: indices_are_sorted = false
|
||||
// CHECK-SAME: scatter_dimension_numbers = {
|
||||
// CHECK-SAME: index_vector_dim = 1 : i64
|
||||
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
|
||||
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
|
||||
// CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }
|
||||
// CHECK-SAME: unique_indices = false
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
%test_select {
|
||||
%Arg_0.1 = pred[2,3] parameter(0)
|
||||
|
Loading…
Reference in New Issue
Block a user