Add import support for HLO Scatter op.

PiperOrigin-RevId: 295792321
Change-Id: I6daf2b0b49d551a446d6e37b9e6f96fbd11fdbfa
This commit is contained in:
Prakalp Srivastava 2020-02-18 12:10:36 -08:00 committed by TensorFlower Gardener
parent 19ac5f4f6c
commit caad1b7a45
3 changed files with 67 additions and 0 deletions
tensorflow/compiler/mlir/xla

View File

@ -370,6 +370,22 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kScatter: {
auto scatter = static_cast<HloScatterInstruction*>(instruction);
attributes.push_back(
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(scatter->indices_are_sorted())));
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
&scatter_op.update_computation()));
return scatter_op.getOperation();
}
case HloOpcode::kSetDimensionSize: {
attributes.push_back(builder_->getNamedAttr(
"dimension", builder_->getIntegerAttr(builder_->getIntegerType(32),
@ -844,6 +860,22 @@ mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
return builder_->getNamedAttr("dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64_t> inserted_window_dims(
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
Convert(update_window_dims), Convert(inserted_window_dims),
Convert(scatter_dims_to_operand_dims),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs) {

View File

@ -121,6 +121,10 @@ class HloFunctionImporter {
mlir::NamedAttribute ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums);
// Converts the scatter dimensions to attributes.
mlir::NamedAttribute ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums);
// Converts XLA instruction source target pairs to MLIR attribute.
mlir::NamedAttribute ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&

View File

@ -716,6 +716,37 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %Arg_0.1 = f32[] parameter(0)
}
// Test scatter
%update_computation {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs)
}
%test_scatter {
%input_tensor = f32[200,100,300] parameter(0)
%scatter_indices = s64[10,2] parameter(1)
%updates = f32[10,300] parameter(2)
ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation
}
// CHECK-LABEL: func @test_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers = {
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: update_window_dims = dense<1> : tensor<1xi64>
// CHECK-SAME: }
// CHECK-SAME: unique_indices = false
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%test_select {
%Arg_0.1 = pred[2,3] parameter(0)