Add ScatterOp to HLO exporter.
PiperOrigin-RevId: 282493169
Change-Id: Iee8137d55f7ea4885015d27f615e6c243e4aee91
commit f483d811f4
parent 843c022f27
@@ -756,6 +756,31 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
  let hasCustomHLOConverter = 1;
}

def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect,
    [StructFieldAttr<"update_window_dims", I64ElementsAttr>,
     StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
     StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
     StructFieldAttr<"index_vector_dim", I64Attr>]> {
  let description = "Structure of dimension information for scatter";
}

def HLO_ScatterOp: HLO_Op<"scatter", [NoSideEffect]>, BASE_HLO_ScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scatter_indices,
    HLO_Tensor:$updates,
    ScatterDimensionNumbers:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let regions = (region SizedRegion<1>:$update_computation);

  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

// TODO(jpienaar): Add broadcastable trait.
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect]>, BASE_HLO_SelectOp {
  let arguments = (ins
@@ -784,6 +784,18 @@ class BASE_HLO_ReshapeOp {
  }];
}

class BASE_HLO_ScatterOp {
  string summary = "Scatter operator";

  string description = [{
    Generates a result which is the value of the input array `operand`,
    with several slices (at indices specified by `scatter_indices`)
    updated with the values in `updates` using `update_computation`.

    See https://www.tensorflow.org/xla/operation_semantics#scatter.
  }];
}

class BASE_HLO_SelectOp {
  string summary = "Select operator";
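To make the scatter description above concrete, the core update rule can be sketched as a deliberately simplified one-dimensional C++ routine. This is illustrative only; it ignores the dimension-number machinery and multi-element update windows that the real op supports:

#include <functional>
#include <iostream>
#include <vector>

// Simplified scatter: each entry of `scatter_indices` picks a position in
// `operand`, and the matching element of `updates` is combined with the
// existing value through `update_computation`.
std::vector<float> Scatter1D(
    std::vector<float> operand, const std::vector<int>& scatter_indices,
    const std::vector<float>& updates,
    const std::function<float(float, float)>& update_computation) {
  for (size_t i = 0; i < scatter_indices.size(); ++i) {
    float& slot = operand[scatter_indices[i]];
    slot = update_computation(slot, updates[i]);
  }
  return operand;
}

int main() {
  // operand = [0, 10, 20, 30], indices = [3, 1], updates = [5, 7], add.
  for (float v : Scatter1D({0, 10, 20, 30}, {3, 1}, {5, 7},
                           [](float lhs, float rhs) { return lhs + rhs; })) {
    std::cout << v << " ";  // prints: 0 17 20 35
  }
  std::cout << "\n";
}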
@@ -49,6 +49,34 @@ using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;

// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
  return t;
}

template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
  return t.get();
}

// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }

// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
  const auto& semantics = value.getSemantics();
  bool losesInfo = false;
  if (&semantics != &llvm::APFloat::IEEEdouble())
    value.convert(llvm::APFloat::IEEEdouble(),
                  llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  return value.convertToDouble();
}

static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
  auto values = attr.getValues<int64>();
  return {values.begin(), values.end()};
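The Unwrap pair above lets TableGen-generated exporter code pass arguments uniformly, whether they are plain values or unique_ptr-owned objects handed to XLA builder functions that expect raw pointers. A minimal standalone sketch of that behaviour follows; PrecisionConfig and Consume here are stand-ins invented for illustration, not the real XLA types:

#include <iostream>
#include <memory>

// Same idea as the exporter's helpers: values pass through unchanged, while a
// unique_ptr is turned into the raw pointer the callee expects.
template <typename T>
T Unwrap(T t) {
  return t;
}

template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
  return t.get();
}

// Stand-in for an XLA builder function such as xla::Dot, which takes a raw
// PrecisionConfig* rather than a smart pointer.
struct PrecisionConfig { int operand_precision = 0; };
void Consume(const PrecisionConfig* config) {
  std::cout << (config ? config->operand_precision : -1) << "\n";
}

int main() {
  auto config = std::make_unique<PrecisionConfig>();
  config->operand_precision = 2;
  Consume(Unwrap(config));  // unique_ptr -> raw pointer, no ownership change.
  int x = Unwrap(7);        // plain values are returned as-is.
  std::cout << x << "\n";
}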
(This hunk removes the Unwrap, ConvertAPInt, and ConvertAPFloat helpers from their old location, since they now appear earlier in the file, and adds Convert_scatter_dimension_numbers in their place.)

@@ -226,32 +254,30 @@ static xla::ComparisonDirection Convert_comparison_direction(
      .ValueOrDie();
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
    mlir::xla_hlo::ScatterDimensionNumbers input) {
  xla::ScatterDimensionNumbers output;

  auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
  std::copy(update_window_dims.begin(), update_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_update_window_dims()));

  auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
  std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_inserted_window_dims()));

  auto scatter_dims_to_operand_dims =
      ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
  std::copy(scatter_dims_to_operand_dims.begin(),
            scatter_dims_to_operand_dims.end(),
            tensorflow::protobuf::RepeatedFieldBackInserter(
                output.mutable_scatter_dims_to_operand_dims()));

  output.set_index_vector_dim(
      ConvertAPInt(input.index_vector_dim().getValue()));
  return output;
}

namespace mlir {
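As a worked example, for the dimension numbers used in the scatter.mlir test added below, this conversion yields a proto equivalent to the following hand-written sketch. The helper name and the include path are assumptions made for illustration:

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Hypothetical helper, equivalent to what Convert_scatter_dimension_numbers
// produces for the attributes in the scatter.mlir test below.
xla::ScatterDimensionNumbers MakeTestScatterDimensionNumbers() {
  xla::ScatterDimensionNumbers output;
  output.add_update_window_dims(1);            // update_window_dims = [1]
  output.add_inserted_window_dims(0);          // inserted_window_dims = [0, 1]
  output.add_inserted_window_dims(1);
  output.add_scatter_dims_to_operand_dims(0);  // scatter_dims_to_operand_dims = [0, 1]
  output.add_scatter_dims_to_operand_dims(1);
  output.set_index_vector_dim(1);              // index_vector_dim = 1
  return output;
}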
@@ -535,6 +561,22 @@ LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
  return success();
}

LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation update_computation;
  if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
                                                     &update_computation))) {
    return failure();
  }
  xla::ScatterDimensionNumbers dimension_numbers =
      Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
  value_map[op] = xla::Scatter(
      value_map[op.operand()], value_map[op.scatter_indices()],
      value_map[op.updates()], update_computation, dimension_numbers,
      op.indices_are_sorted(), op.unique_indices());
  return success();
}

LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  xla::XlaComputation select;
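For reference, the net effect of this lowering on the test module below can be sketched directly against the XlaBuilder client API. This is an illustrative reconstruction, not code from this change, and it reuses the hypothetical MakeTestScatterDimensionNumbers helper sketched above:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp BuildTestScatter(xla::XlaBuilder* builder) {
  // Update computation: scalar f32 addition, mirroring the op's region.
  auto sub_builder = builder->CreateSubBuilder("update_computation");
  xla::Shape scalar_f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
  xla::Add(xla::Parameter(sub_builder.get(), 0, scalar_f32, "lhs"),
           xla::Parameter(sub_builder.get(), 1, scalar_f32, "rhs"));
  xla::XlaComputation update_computation = sub_builder->Build().ValueOrDie();

  // Parameters with the same shapes as the test module's arguments.
  xla::XlaOp operand = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {200, 100, 300}),
      "operand");
  xla::XlaOp scatter_indices = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {10, 2}),
      "scatter_indices");
  xla::XlaOp updates = xla::Parameter(
      builder, 2, xla::ShapeUtil::MakeShape(xla::F32, {10, 300}), "updates");

  return xla::Scatter(operand, scatter_indices, updates, update_computation,
                      MakeTestScatterDimensionNumbers(),
                      /*indices_are_sorted=*/true, /*unique_indices=*/true);
}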
tensorflow/compiler/mlir/xla/tests/translate/scatter.mlir (new file, 27 lines)
@@ -0,0 +1,27 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s

func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
  %0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
    %add = xla_hlo.add %lhs, %rhs : tensor<f32>
    "xla_hlo.return"(%add) : (tensor<f32>) -> ()
  }) {
    scatter_dimension_numbers = {
      update_window_dims = dense<[1]> : tensor<1xi64>,
      inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
      scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
      index_vector_dim = 1 : i64
    },
    indices_are_sorted = true,
    unique_indices = true
  } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
  return %0 : tensor<200x100x300xf32>
}

// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]