Add ScatterOp to HLO exporter.

PiperOrigin-RevId: 282493169
Change-Id: Iee8137d55f7ea4885015d27f615e6c243e4aee91
This commit is contained in:
Prakalp Srivastava 2019-11-25 21:32:25 -08:00 committed by TensorFlower Gardener
parent 843c022f27
commit f483d811f4
4 changed files with 129 additions and 23 deletions

View File

@ -756,6 +756,31 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
let hasCustomHLOConverter = 1;
}
def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
def HLO_ScatterOp: HLO_Op<"scatter", [NoSideEffect]>, BASE_HLO_ScatterOp {
let arguments = (ins
HLO_Tensor:$operand,
HLO_Tensor:$scatter_indices,
HLO_Tensor:$updates,
ScatterDimensionNumbers:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
);
let regions = (region SizedRegion<1>:$update_computation);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
// TODO(jpienaar): Add broadcastable trait.
def HLO_SelectOp: HLO_Op<"select", [NoSideEffect]>, BASE_HLO_SelectOp {
let arguments = (ins

View File

@ -784,6 +784,18 @@ class BASE_HLO_ReshapeOp {
}];
}
class BASE_HLO_ScatterOp {
string summary = "Scatter operator";
string description = [{
Generates a result which is the value of the input array `operand`,
with several slices (at indices specified by `scatter_indices`)
updated with the values in `updates` using `update_computation`.
See https://www.tensorflow.org/xla/operation_semantics#scatter.
}];
}
class BASE_HLO_SelectOp {
string summary = "Select operator";

View File

@ -49,6 +49,34 @@ using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
return t;
}
template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
return t.get();
}
// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
const auto& semantics = value.getSemantics();
bool losesInfo = false;
if (&semantics != &llvm::APFloat::IEEEdouble())
value.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return value.convertToDouble();
}
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
auto values = attr.getValues<int64>();
return {values.begin(), values.end()};
@ -226,32 +254,30 @@ static xla::ComparisonDirection Convert_comparison_direction(
.ValueOrDie();
}
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
return t;
}
static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::xla_hlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;
template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
return t.get();
}
auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
std::copy(update_window_dims.begin(), update_window_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_update_window_dims()));
// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
auto inserted_window_dims = ConvertDenseIntAttr(input.inserted_window_dims());
std::copy(inserted_window_dims.begin(), inserted_window_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_inserted_window_dims()));
// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
const auto& semantics = value.getSemantics();
bool losesInfo = false;
if (&semantics != &llvm::APFloat::IEEEdouble())
value.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return value.convertToDouble();
auto scatter_dims_to_operand_dims =
ConvertDenseIntAttr(input.scatter_dims_to_operand_dims());
std::copy(scatter_dims_to_operand_dims.begin(),
scatter_dims_to_operand_dims.end(),
tensorflow::protobuf::RepeatedFieldBackInserter(
output.mutable_scatter_dims_to_operand_dims()));
output.set_index_vector_dim(
ConvertAPInt(input.index_vector_dim().getValue()));
return output;
}
namespace mlir {
@ -535,6 +561,22 @@ LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) {
return success();
}
LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation update_computation;
if (failed(ctx.converter->LowerRegionAsComputation(&op.update_computation(),
&update_computation))) {
return failure();
}
xla::ScatterDimensionNumbers dimension_numbers =
Convert_scatter_dimension_numbers(op.scatter_dimension_numbers());
value_map[op] = xla::Scatter(
value_map[op.operand()], value_map[op.scatter_indices()],
value_map[op.updates()], update_computation, dimension_numbers,
op.indices_are_sorted(), op.unique_indices());
return success();
}
LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaComputation select;

View File

@ -0,0 +1,27 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
index_vector_dim = 1 : i64
},
indices_are_sorted = true,
unique_indices = true
} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32>
return %0 : tensor<200x100x300xf32>
}
// CHECK: [[COMPUTATION:%.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[]
// CHECK-LABEL: ENTRY
// CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1)
// CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]]