Merge pull request #31904 from nouiz:scatter_no_atomic

PiperOrigin-RevId: 266930030
This commit is contained in:
TensorFlower Gardener 2019-09-03 08:12:45 -07:00
commit 89ce4c791c
16 changed files with 146 additions and 26 deletions

View File

@ -1746,11 +1746,13 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaOp& updates,
const XlaComputation& update_computation, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, const ScatterDimensionNumbers& dimension_numbers,
bool indices_are_sorted) { bool indices_are_sorted, bool unique_indices) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr; HloInstructionProto instr;
instr.set_indices_are_sorted(indices_are_sorted); instr.set_indices_are_sorted(indices_are_sorted);
instr.set_unique_indices(unique_indices);
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
GetShape(scatter_indices)); GetShape(scatter_indices));
@ -3390,10 +3392,10 @@ XlaOp Gather(const XlaOp input, const XlaOp start_indices,
XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices, XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
const XlaOp updates, const XlaComputation& update_computation, const XlaOp updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, const ScatterDimensionNumbers& dimension_numbers,
bool indices_are_sorted) { bool indices_are_sorted, bool unique_indices) {
return input.builder()->Scatter(input, scatter_indices, updates, return input.builder()->Scatter(input, scatter_indices, updates,
update_computation, dimension_numbers, update_computation, dimension_numbers,
indices_are_sorted); indices_are_sorted, unique_indices);
} }
void Send(const XlaOp operand, const ChannelHandle& handle) { void Send(const XlaOp operand, const ChannelHandle& handle) {

View File

@ -592,7 +592,7 @@ class XlaBuilder {
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
const XlaOp& updates, const XlaComputation& update_computation, const XlaOp& updates, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, const ScatterDimensionNumbers& dimension_numbers,
bool indices_are_sorted = false); bool indices_are_sorted = false, bool unique_indices = false);
void Send(const XlaOp& operand, const ChannelHandle& handle); void Send(const XlaOp& operand, const ChannelHandle& handle);
XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
@ -1010,7 +1010,7 @@ class XlaBuilder {
friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, const ScatterDimensionNumbers& dimension_numbers,
bool indices_are_sorted); bool indices_are_sorted, bool unique_indices);
friend void Send(XlaOp operand, const ChannelHandle& handle); friend void Send(XlaOp operand, const ChannelHandle& handle);
friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle); const ChannelHandle& handle);
@ -1869,7 +1869,7 @@ XlaOp Gather(XlaOp input, XlaOp start_indices,
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation, const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, const ScatterDimensionNumbers& dimension_numbers,
bool indices_are_sorted = false); bool indices_are_sorted = false, bool unique_indices = false);
// Enqueues a Send node onto the computation for device-to-device // Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to // communication. This operation sends the given operand to

View File

@ -1382,6 +1382,9 @@ For a more intuitive description, see the "Informal Description" section below.
| `indices_are_sorted` | `bool` | Whether the indices are | | `indices_are_sorted` | `bool` | Whether the indices are |
: : : guaranteed to be sorted by : : : : guaranteed to be sorted by :
: : : the caller. : : : : the caller. :
| `unique_indices` | `bool` | Whether the indices are |
: : : guaranteed to be unique by :
: : : the caller. :
For convenience, we label dimensions in the output array not in `offset_dims` For convenience, we label dimensions in the output array not in `offset_dims`
as `batch_dims`. as `batch_dims`.
@ -1450,6 +1453,11 @@ If `indices_are_sorted` is set to true then XLA can assume that `start_indices`
are sorted (in ascending `start_index_map` order) by the user. If they are not are sorted (in ascending `start_index_map` order) by the user. If they are not
then the semantics is implementation defined. then the semantics is implementation defined.
If `unique_indices` is set to true then XLA can assume that all elements
scattered to are unique, so XLA could use non-atomic operations. If
`unique_indices` is set to true and the indices being scattered to are not
unique then the semantics is implementation defined.
### Informal Description and Examples ### Informal Description and Examples
Informally, every index `Out` in the output array corresponds to an element `E` Informally, every index `Out` in the output array corresponds to an element `E`

View File

@ -1544,11 +1544,12 @@ class ComputationBuilder(object):
updates, updates,
update_computation, update_computation,
dimension_numbers, dimension_numbers,
indices_are_sorted=False): indices_are_sorted=False,
unique_indices=False):
"""Enqueues a Scatter operation onto the computation.""" """Enqueues a Scatter operation onto the computation."""
return ops.Scatter(a, scatter_indices, updates, return ops.Scatter(a, scatter_indices, updates,
update_computation.computation, dimension_numbers, update_computation.computation, dimension_numbers,
indices_are_sorted) indices_are_sorted, unique_indices)
def Fft(self, operand, fft_type, fft_lengths): def Fft(self, operand, fft_type, fft_lengths):
"""Enqueues a FFT operation onto the computation.""" """Enqueues a FFT operation onto the computation."""

View File

@ -965,8 +965,15 @@ Status IrEmitterUnnested::EmitScatter(
updates->shape().element_type(), module_)); updates->shape().element_type(), module_));
TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
Store(input_ir_value, input_address); Store(input_ir_value, input_address);
if (!scatter->unique_indices()) {
return EmitAtomicOperationForNestedComputation( return EmitAtomicOperationForNestedComputation(
*scatter->to_apply(), output_address, input_address); *scatter->to_apply(), output_address, input_address);
} else {
return EmitCallToNestedComputation(*scatter->to_apply(),
{output_address, input_address},
output_address);
}
}; };
// Launch a kernel that reads every element in the updates tensor. We could // Launch a kernel that reads every element in the updates tensor. We could

View File

@ -188,7 +188,12 @@ class IrEmitterUnnested : public IrEmitter,
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. `scatter` may be fused, scatter indices are taken from // the process. `scatter` may be fused, scatter indices are taken from
// `scatter_indices_gen`, updates from`updates_gen`. The output buffer is // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
// expected to have the operand values in it already. // expected to have the operand values in it already. If unique_indices
// is false, we will use an atomic update. Using true for unique_indices
// is safe only when it is guaranteed that there are no duplicate
// indices.
// When using unique_indices=true, it is the caller's responsibility to
// ensure there is no overlap.
Status EmitScatter(Thunk* thunk, HloInstruction* scatter, Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
const llvm_ir::ElementGenerator& scatter_indices_gen, const llvm_ir::ElementGenerator& scatter_indices_gen,
const llvm_ir::ElementGenerator& updates_gen); const llvm_ir::ElementGenerator& updates_gen);

View File

@ -53,6 +53,33 @@ CHECK: store atomic{{.*}}unordered, align 4
)"); )");
} }
// Checks that a scatter annotated with unique_indices=true is compiled to IR
// that does not contain an unordered atomic store.
TEST_F(GpuAtomicTest, TestStoreNoAtomic) {
// The update computation simply returns its second parameter, and the scatter
// instruction carries unique_indices=true.
const char* hlo_string = R"(
HloModule TensorFlowScatterV1
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1, unique_indices=true
}
)";
// CHECK-NOT: the generated IR must not match the atomic-store pattern.
CompileAndVerifyIr(hlo_string, R"(
CHECK-NOT: store atomic{{.*}}unordered, align 4
)");
}
} // namespace } // namespace
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true; option cc_enable_arenas = true;
// Serialization of HloInstruction. // Serialization of HloInstruction.
// Next ID: 69 // Next ID: 70
message HloInstructionProto { message HloInstructionProto {
reserved 10; reserved 10;
reserved "parameter_name"; reserved "parameter_name";
@ -237,6 +237,10 @@ message HloInstructionProto {
// Frontend attributes to pass to the XLA backend. // Frontend attributes to pass to the XLA backend.
xla.FrontendAttributes frontend_attributes = 68; xla.FrontendAttributes frontend_attributes = 68;
// Specifies if all elements updated are guaranteed to be unique by
// the caller.
bool unique_indices = 69;
} }
// Serialization of HloComputation. // Serialization of HloComputation.

View File

@ -563,9 +563,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
auto scatter_dimension_numbers = auto scatter_dimension_numbers =
absl::make_unique<ScatterDimensionNumbers>( absl::make_unique<ScatterDimensionNumbers>(
proto.scatter_dimension_numbers()); proto.scatter_dimension_numbers());
instruction = CreateScatter(shape, operands(0), operands(1), operands(2), instruction =
CreateScatter(shape, operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers, computations(0), *scatter_dimension_numbers,
proto.indices_are_sorted()); proto.indices_are_sorted(), proto.unique_indices());
break; break;
} }
case HloOpcode::kIota: case HloOpcode::kIota:
@ -1392,11 +1393,11 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand, const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates, HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation, HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers, const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
bool indices_are_sorted) { bool unique_indices) {
return absl::make_unique<HloScatterInstruction>( return absl::make_unique<HloScatterInstruction>(
shape, operand, scatter_indices, updates, update_computation, shape, operand, scatter_indices, updates, update_computation,
scatter_dim_numbers, indices_are_sorted); scatter_dim_numbers, indices_are_sorted, unique_indices);
} }
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(

View File

@ -801,7 +801,7 @@ class HloInstruction {
HloInstruction* scatter_indices, HloInstruction* updates, HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation, HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers, const ScatterDimensionNumbers& scatter_dim_numbers,
bool indices_are_sorted); bool indices_are_sorted, bool unique_indices);
// Creates a kDomain instruction which delimits an HLO domain which have // Creates a kDomain instruction which delimits an HLO domain which have
// the provided user and operand side metadata. // the provided user and operand side metadata.
@ -1629,6 +1629,9 @@ class HloInstruction {
LOG(FATAL) << "Unimplemented method."; LOG(FATAL) << "Unimplemented method.";
} }
// Returns the unique_indices field. This base-class version is a stub that
// aborts with LOG(FATAL); instruction classes that actually store the field
// (e.g. HloScatterInstruction) override it.
virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }
// Returns data on the dimension numbers used for a convolution operation, // Returns data on the dimension numbers used for a convolution operation,
// which may be a kConvolution instruction or a kCustomCall that implements a // which may be a kConvolution instruction or a kCustomCall that implements a
// convolution. // convolution.

View File

@ -1529,7 +1529,8 @@ TEST_F(HloInstructionTest, StringifyScatter) {
/*inserted_window_dims=*/{}, /*inserted_window_dims=*/{},
/*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
/*index_vector_dim=*/2), /*index_vector_dim=*/2),
/*indices_are_sorted=*/false)); /*indices_are_sorted=*/false,
/*unique_indices=*/false));
module->AddEntryComputation(builder.Build()); module->AddEntryComputation(builder.Build());
EXPECT_EQ( EXPECT_EQ(

View File

@ -2493,9 +2493,11 @@ HloScatterInstruction::HloScatterInstruction(
const Shape& shape, HloInstruction* operand, const Shape& shape, HloInstruction* operand,
HloInstruction* scatter_indices, HloInstruction* updates, HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation, HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted) const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
bool unique_indices)
: HloInstruction(HloOpcode::kScatter, shape), : HloInstruction(HloOpcode::kScatter, shape),
indices_are_sorted_(indices_are_sorted) { indices_are_sorted_(indices_are_sorted),
unique_indices_(unique_indices) {
AppendOperand(operand); AppendOperand(operand);
AppendOperand(scatter_indices); AppendOperand(scatter_indices);
AppendOperand(updates); AppendOperand(updates);
@ -2550,6 +2552,7 @@ HloInstructionProto HloScatterInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto(); HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
proto.set_indices_are_sorted(indices_are_sorted()); proto.set_indices_are_sorted(indices_are_sorted());
proto.set_unique_indices(unique_indices());
return proto; return proto;
} }
@ -2560,6 +2563,9 @@ std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
if (indices_are_sorted()) { if (indices_are_sorted()) {
attrs.push_back("indices_are_sorted=true"); attrs.push_back("indices_are_sorted=true");
} }
if (unique_indices()) {
attrs.push_back("unique_indices=true");
}
return attrs; return attrs;
} }
@ -2572,7 +2578,8 @@ bool HloScatterInstruction::IdenticalSlowPath(
scatter_dimension_numbers(), scatter_dimension_numbers(),
casted_other.scatter_dimension_numbers()) && casted_other.scatter_dimension_numbers()) &&
eq_computations(to_apply(), casted_other.to_apply()) && eq_computations(to_apply(), casted_other.to_apply()) &&
indices_are_sorted() == casted_other.indices_are_sorted(); indices_are_sorted() == casted_other.indices_are_sorted() &&
unique_indices() == casted_other.unique_indices();
} }
std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
@ -2581,7 +2588,7 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 3); CHECK_EQ(new_operands.size(), 3);
return absl::make_unique<HloScatterInstruction>( return absl::make_unique<HloScatterInstruction>(
shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
scatter_dimension_numbers(), indices_are_sorted()); scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
} }
HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)

View File

@ -1453,7 +1453,7 @@ class HloScatterInstruction : public HloInstruction {
HloInstruction* scatter_indices, HloInstruction* updates, HloInstruction* scatter_indices, HloInstruction* updates,
HloComputation* update_computation, HloComputation* update_computation,
const ScatterDimensionNumbers& scatter_dim_numbers, const ScatterDimensionNumbers& scatter_dim_numbers,
bool indices_are_sorted); bool indices_are_sorted, bool unique_indices);
const ScatterDimensionNumbers& scatter_dimension_numbers() const { const ScatterDimensionNumbers& scatter_dimension_numbers() const {
CHECK(scatter_dimension_numbers_ != nullptr); CHECK(scatter_dimension_numbers_ != nullptr);
return *scatter_dimension_numbers_; return *scatter_dimension_numbers_;
@ -1462,6 +1462,7 @@ class HloScatterInstruction : public HloInstruction {
void set_indices_are_sorted(bool indices_are_sorted) { void set_indices_are_sorted(bool indices_are_sorted) {
indices_are_sorted_ = indices_are_sorted; indices_are_sorted_ = indices_are_sorted;
} }
bool unique_indices() const override { return unique_indices_; }
// Returns a serialized representation of this instruction. // Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override; HloInstructionProto ToProto() const override;
@ -1489,6 +1490,7 @@ class HloScatterInstruction : public HloInstruction {
std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
bool indices_are_sorted_; bool indices_are_sorted_;
bool unique_indices_;
}; };
class HloIotaInstruction : public HloInstruction { class HloIotaInstruction : public HloInstruction {

View File

@ -1726,6 +1726,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
optional<bool> indices_are_sorted = false; optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool, attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted}; &indices_are_sorted};
optional<bool> unique_indices = false;
attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
&unique_indices};
if (!ParseOperands(&operands, /*expected_size=*/3) || if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) { !ParseAttributes(attrs)) {
@ -1742,7 +1745,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
instruction = builder->AddInstruction(HloInstruction::CreateScatter( instruction = builder->AddInstruction(HloInstruction::CreateScatter(
shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
/*updates=*/operands[2], *update_computation, dim_numbers, /*updates=*/operands[2], *update_computation, dim_numbers,
indices_are_sorted.value())); indices_are_sorted.value(), unique_indices.value()));
break; break;
} }
case HloOpcode::kDomain: { case HloOpcode::kDomain: {

View File

@ -934,6 +934,25 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
} }
)"
},
{
"UniqueIndicesScatter",
R"(HloModule StringifyUniqueIndicesScatter
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
%updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
}
)" )"
}, },
{ {

View File

@ -225,6 +225,36 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates); RunTest(hlo_text, &operand, &scatter_indices, &updates);
} }
// Runs an add-combining scatter whose instruction is annotated with
// unique_indices=true.
XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_UniqueIndices) {
const string hlo_text = R"(
HloModule TensorFlowScatter_Add
add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter = s32[3,3] scatter(operand, indices, updates),
to_apply=add_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1,
unique_indices=true
}
)";
Literal operand =
LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
// The two scatter indices (0 and 2) are distinct, so the unique_indices=true
// annotation in the HLO above holds for this input.
Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
// NOTE(review): RunTest is defined outside this view; presumably it executes
// the module and checks the result against a reference -- confirm.
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
const string hlo_text = R"( const string hlo_text = R"(
HloModule TensorFlowScatter_Mul HloModule TensorFlowScatter_Mul