From caad1b7a45c593e83adbc2df0f099e783aff48e8 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Tue, 18 Feb 2020 12:10:36 -0800 Subject: [PATCH] Add import support for HLO Scatter op. PiperOrigin-RevId: 295792321 Change-Id: I6daf2b0b49d551a446d6e37b9e6f96fbd11fdbfa --- .../mlir/xla/hlo_function_importer.cc | 32 +++++++++++++++++++ .../compiler/mlir/xla/hlo_function_importer.h | 4 +++ .../mlir/xla/tests/translate/import.hlotxt | 31 ++++++++++++++++++ 3 files changed, 67 insertions(+) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 6081f2e1461..bc9bdf49a39 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -370,6 +370,22 @@ StatusOr HloFunctionImporter::ImportInstruction( Convert(interior_padding)) .getOperation(); } + case HloOpcode::kScatter: { + auto scatter = static_cast(instruction); + attributes.push_back( + ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers())); + attributes.push_back(builder_->getNamedAttr( + "indices_are_sorted", + builder_->getBoolAttr(scatter->indices_are_sorted()))); + attributes.push_back(builder_->getNamedAttr( + "unique_indices", builder_->getBoolAttr(scatter->unique_indices()))); + + auto scatter_op = func_builder->create( + loc, result_type, operands, attributes); + TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(), + &scatter_op.update_computation())); + return scatter_op.getOperation(); + } case HloOpcode::kSetDimensionSize: { attributes.push_back(builder_->getNamedAttr( "dimension", builder_->getIntegerAttr(builder_->getIntegerType(32), @@ -844,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers( return builder_->getNamedAttr("dimension_numbers", attr); } +mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums) { + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector inserted_window_dims( + dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get( + Convert(update_window_dims), Convert(inserted_window_dims), + Convert(scatter_dims_to_operand_dims), + builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_); + return builder_->getNamedAttr("scatter_dimension_numbers", attr); +} + mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs( const std::vector>& source_target_pairs) { diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index d373e88e1c0..93c8e6e818c 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -121,6 +121,10 @@ class HloFunctionImporter { mlir::NamedAttribute ConvertGatherDimensionNumbers( const xla::GatherDimensionNumbers& dnums); + // Converts the scatter dimensions to attributes. + mlir::NamedAttribute ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums); + // Converts XLA instruction source target pairs to MLIR attribute. mlir::NamedAttribute ConvertSourceTargetPairs( const std::vector>& diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index a02db66cd47..b2dec8c950f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %Arg_0.1 = f32[] parameter(0) } +// Test scatter +%update_computation { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs) +} + +%test_scatter { + %input_tensor = f32[200,100,300] parameter(0) + %scatter_indices = s64[10,2] parameter(1) + %updates = f32[10,300] parameter(2) + ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation +} + +// CHECK-LABEL: func @test_scatter +// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32> +// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( { +// CHECK: ^bb0([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): +// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]] +// CHECK: "xla_hlo.return"([[ADD]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: indices_are_sorted = false +// CHECK-SAME: scatter_dimension_numbers = { +// CHECK-SAME: index_vector_dim = 1 : i64 +// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64> +// CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64> +// CHECK-SAME: } +// CHECK-SAME: unique_indices = false + + // CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { %test_select { %Arg_0.1 = pred[2,3] parameter(0)