change the new parameter name.

parent c77c165dcd
commit 2a99d7e898
@@ -1734,12 +1734,12 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
 const XlaOp& updates,
 const XlaComputation& update_computation,
 const ScatterDimensionNumbers& dimension_numbers,
-bool indices_are_sorted, bool use_atomic) {
+bool indices_are_sorted, bool unique_indices) {
 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
 HloInstructionProto instr;
 instr.set_indices_are_sorted(indices_are_sorted);

-instr.set_use_atomic(use_atomic);
+instr.set_unique_indices(unique_indices);

 TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
 TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,

@@ -3380,10 +3380,10 @@ XlaOp Gather(const XlaOp input, const XlaOp start_indices,
 XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
 const XlaOp updates, const XlaComputation& update_computation,
 const ScatterDimensionNumbers& dimension_numbers,
-bool indices_are_sorted, bool use_atomic) {
+bool indices_are_sorted, bool unique_indices) {
 return input.builder()->Scatter(input, scatter_indices, updates,
 update_computation, dimension_numbers,
-indices_are_sorted, use_atomic);
+indices_are_sorted, unique_indices);
 }

 void Send(const XlaOp operand, const ChannelHandle& handle) {

@@ -592,7 +592,7 @@ class XlaBuilder {
 XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
 const XlaOp& updates, const XlaComputation& update_computation,
 const ScatterDimensionNumbers& dimension_numbers,
-bool indices_are_sorted = false, bool use_atomic = true);
+bool indices_are_sorted = false, bool unique_indices = false);

 void Send(const XlaOp& operand, const ChannelHandle& handle);
 XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,

@@ -1010,7 +1010,7 @@ class XlaBuilder {
 friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
 const XlaComputation& update_computation,
 const ScatterDimensionNumbers& dimension_numbers,
-bool indices_are_sorted, bool use_atomic);
+bool indices_are_sorted, bool unique_indices);
 friend void Send(XlaOp operand, const ChannelHandle& handle);
 friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
 const ChannelHandle& handle);
@@ -1869,7 +1869,7 @@ XlaOp Gather(XlaOp input, XlaOp start_indices,
 XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
 const XlaComputation& update_computation,
 const ScatterDimensionNumbers& dimension_numbers,
-bool indices_are_sorted = false, bool use_atomic = true);
+bool indices_are_sorted = false, bool unique_indices = false);

 // Enqueues a Send node onto the computation for device-to-device
 // communication. This operation sends the given operand to
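Note that the default flips polarity along with the rename: the old API defaulted to use_atomic=true (atomics on), while the new API defaults to unique_indices=false, which preserves the same behavior. A call site that previously opted out of atomics must now assert index uniqueness instead. A minimal, hypothetical migration sketch follows (the operands, update computation, and dimension numbers are assumed to have been built earlier; they are not part of this change):

// Before: explicitly disable atomic updates.
//   result = xla::Scatter(input, indices, updates, add_computation, dnums,
//                         /*indices_are_sorted=*/false, /*use_atomic=*/false);
// After: promise that no two scatter indices collide instead.
xla::XlaOp result = xla::Scatter(input, indices, updates, add_computation, dnums,
                                 /*indices_are_sorted=*/false,
                                 /*unique_indices=*/true);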
@@ -1382,10 +1382,9 @@ For a more intuitive description, see the "Informal Description" section below.
 | `indices_are_sorted` | `bool` | Whether the indices are |
 : : : guaranteed to be sorted by :
 : : : the caller. :
-| `use_atomic` | `bool` | Whether to use atomic |
-: : : operation for the update. To :
-: : : use only when the the caller :
-: : : guarante no duplicate indices :
+| `unique_indices` | `bool` | Whether the indices are |
+: : : guaranteed to be unique by :
+: : : the caller :

 For convenience, we label dimensions in the output array not in `offset_dims`
 as `batch_dims`.

@@ -1454,8 +1453,10 @@ If `indices_are_sorted` is set to true then XLA can assume that `start_indices`
 are sorted (in ascending `start_index_map` order) by the user. If they are not
 then the semantics is implementation defined.

-If `use_atomic` is set to false then XLA will not use atomic operation. This
-is only safe when there is no duplicate indices.
+If `unique_indices` is set to true then XLA can assume that all
+element scattered to are unique. So XLA could use non-atomic
+operation. If they are not, then the semantics is implementation
+defined.

 ### Informal Description and Examples

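To make the documented semantics concrete, here is a hedged C++ sketch of building a scatter through the client API after this rename. It relies only on the Scatter signature shown above; the builder label, the literals, the additive update computation, and the dimension-number choices are illustrative assumptions, not part of the commit:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: scatter three scalar updates into distinct rows of a 1-D operand.
// Because the caller knows the indices {0, 2, 4} never collide, it can pass
// unique_indices=true and let the backend skip atomic updates.
xla::XlaComputation BuildScatterExample() {
  xla::XlaBuilder b("scatter_unique_indices_example");

  // Update computation: element-wise addition of the old and new values.
  auto sub = b.CreateSubBuilder("add");
  auto lhs = xla::Parameter(sub.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "lhs");
  auto rhs = xla::Parameter(sub.get(), 1, xla::ShapeUtil::MakeShape(xla::F32, {}), "rhs");
  xla::Add(lhs, rhs);
  xla::XlaComputation add = sub->BuildAndNoteError();

  auto operand = xla::ConstantR1<float>(&b, {0, 0, 0, 0, 0, 0});
  auto indices = xla::ConstantR1<int32_t>(&b, {0, 2, 4});
  auto updates = xla::ConstantR1<float>(&b, {10, 20, 30});

  // Scalar updates: every operand window dimension is an inserted dimension.
  xla::ScatterDimensionNumbers dnums;
  dnums.add_inserted_window_dims(0);
  dnums.add_scatter_dims_to_operand_dims(0);
  dnums.set_index_vector_dim(1);

  xla::Scatter(operand, indices, updates, add, dnums,
               /*indices_are_sorted=*/true, /*unique_indices=*/true);
  return b.Build().ValueOrDie();
}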
@@ -1545,11 +1545,11 @@ class ComputationBuilder(object):
 update_computation,
 dimension_numbers,
 indices_are_sorted=False,
-use_atomic=True):
+unique_indices=False):
 """Enqueues a Scatter operation onto the computation."""
 return ops.Scatter(a, scatter_indices, updates,
 update_computation.computation, dimension_numbers,
-indices_are_sorted, use_atomic)
+indices_are_sorted, unique_indices)

 def Fft(self, operand, fft_type, fft_lengths):
 """Enqueues a FFT operation onto the computation."""

@@ -404,7 +404,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
 scatter_fused_emitter.GetGenerator(root->operand(1)),
 /*updates_gen=*/
 scatter_fused_emitter.GetGenerator(root->operand(2)),
-root->use_atomic()));
+root->unique_indices()));
 }
 AddThunkToThunkSequence(
 absl::make_unique<SequentialThunk>(std::move(thunks), fusion));

@@ -842,7 +842,7 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
 return GetIrArray(*updates, *scatter)
 .EmitReadArrayElement(index, &b_, "update");
 },
-scatter->use_atomic()));
+scatter->unique_indices()));

 // Elide the sequential thunk if there's no copy.
 if (thunks.size() == 1) {

@@ -859,7 +859,7 @@ Status IrEmitterUnnested::EmitScatter(
 Thunk* thunk, HloInstruction* scatter,
 const llvm_ir::ElementGenerator& scatter_indices_gen,
 const llvm_ir::ElementGenerator& updates_gen,
-bool use_atomic) {
+bool unique_indices) {
 const HloInstruction* operand = scatter->operand(0);
 const HloInstruction* scatter_indices = scatter->operand(1);
 const HloInstruction* updates = scatter->operand(2);

@@ -969,7 +969,7 @@ Status IrEmitterUnnested::EmitScatter(
 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
 Store(input_ir_value, input_address);

-if (scatter->use_atomic()) {
+if (!scatter->unique_indices()) {
 return EmitAtomicOperationForNestedComputation(
 *scatter->to_apply(), output_address, input_address);
 } else {
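The branch above is where the renamed attribute changes code generation: when unique_indices is false, duplicate scatter indices may map several updates onto the same output element, so parallel threads must combine through an atomic read-modify-write; when it is true, a plain update suffices. The following standalone C++ program is only an analogy for that race under an assumed additive update computation; it is not the emitter's LLVM IR:

#include <atomic>
#include <cstdio>
#include <thread>

// Analogy: two scatter updates (1.0 and 2.0) land on the same output slot.
// The atomic path always accumulates both; the plain path is a deliberate
// data race and may lose an update, which is why the emitter only takes it
// when unique_indices guarantees such collisions cannot happen.
int main() {
  std::atomic<float> atomic_slot{0.0f};
  float plain_slot = 0.0f;

  auto apply_update = [&](float v) {
    // Atomic combine, roughly what EmitAtomicOperationForNestedComputation
    // provides for an additive update computation.
    float old = atomic_slot.load();
    while (!atomic_slot.compare_exchange_weak(old, old + v)) {
    }
    // Non-atomic combine: only valid if no other update targets this slot.
    plain_slot += v;
  };

  std::thread t1(apply_update, 1.0f);
  std::thread t2(apply_update, 2.0f);
  t1.join();
  t2.join();

  std::printf("atomic: %.1f  plain (racy): %.1f\n",
              atomic_slot.load(), plain_slot);
  return 0;
}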
@@ -226,16 +226,16 @@ class IrEmitterUnnested : public IrEmitter,
 // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
 // the process. `scatter` may be fused, scatter indices are taken from
 // `scatter_indices_gen`, updates from`updates_gen`. The output buffer is
-// expected to have the operand values in it already. If use_atomic
-// is true, we will use an atomic update. Using false for use_atomic
+// expected to have the operand values in it already. If unique_indices
+// is false, we will use an atomic update. Using false for unique_indices
 // is safe only when it is guaranteed that there are no duplicate
 // indices.
-// When using use_atomi=false, it is the caller responsibility to
-// ensure there is overlap.
+// When using unique_indices=true, it is the caller responsibility to
+// ensure there is no overlap.
 Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
 const llvm_ir::ElementGenerator& scatter_indices_gen,
 const llvm_ir::ElementGenerator& updates_gen,
-bool use_atomic);
+bool unique_indices);

 // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
 // for the hlo instruction.
@@ -238,10 +238,9 @@ message HloInstructionProto {
 // Frontend attributes to pass to the XLA backend.
 xla.FrontendAttributes frontend_attributes = 68;

-// Specifies if the scatter should use atomic operation or not. If
-// there is duplicate index, then it should be true to compute the
-// right answer.
-bool use_atomic = 69;
+// Specifies if all elements updated are guaranteed to be unique by
+// the caller.
+bool unique_indices = 69;
 }

 // Serialization of HloComputation.
@@ -565,7 +565,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
 proto.scatter_dimension_numbers());
 instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
 computations(0), *scatter_dimension_numbers,
-proto.indices_are_sorted(), proto.use_atomic());
+proto.indices_are_sorted(), proto.unique_indices());
 break;
 }
 case HloOpcode::kIota:

@@ -1393,10 +1393,10 @@ bool HloInstruction::HasSideEffect() const {
 HloInstruction* scatter_indices, HloInstruction* updates,
 HloComputation* update_computation,
 const ScatterDimensionNumbers& scatter_dim_numbers,
-bool indices_are_sorted, bool use_atomic) {
+bool indices_are_sorted, bool unique_indices) {
 return absl::make_unique<HloScatterInstruction>(
 shape, operand, scatter_indices, updates, update_computation,
-scatter_dim_numbers, indices_are_sorted, use_atomic);
+scatter_dim_numbers, indices_are_sorted, unique_indices);
 }

 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(

@@ -801,7 +801,7 @@ class HloInstruction {
 HloInstruction* scatter_indices, HloInstruction* updates,
 HloComputation* update_computation,
 const ScatterDimensionNumbers& scatter_dim_numbers,
-bool indices_are_sorted, bool use_atomic);
+bool indices_are_sorted, bool unique_indices);

 // Creates a kDomain instruction which delimits an HLO domain which have
 // the provided user and operand side metadata.

@@ -1629,8 +1629,8 @@ class HloInstruction {
 LOG(FATAL) << "Unimplemented method.";
 }

-// Returns the use_atomic field.
-virtual bool use_atomic() const {
+// Returns the unique_indices field.
+virtual bool unique_indices() const {
 LOG(FATAL) << "Unimplemented method.";
 }

@@ -1530,7 +1530,7 @@ TEST_F(HloInstructionTest, StringifyScatter) {
 /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
 /*index_vector_dim=*/2),
 /*indices_are_sorted=*/false,
-/*use_atomic=*/true));
+/*unique_indices=*/false));
 module->AddEntryComputation(builder.Build());

 EXPECT_EQ(
@@ -2494,9 +2494,9 @@ HloScatterInstruction::HloScatterInstruction(
 HloInstruction* scatter_indices, HloInstruction* updates,
 HloComputation* update_computation,
 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
-bool use_atomic)
+bool unique_indices)
 : HloInstruction(HloOpcode::kScatter, shape),
-indices_are_sorted_(indices_are_sorted), use_atomic_(use_atomic) {
+indices_are_sorted_(indices_are_sorted), unique_indices_(unique_indices) {
 AppendOperand(operand);
 AppendOperand(scatter_indices);
 AppendOperand(updates);

@@ -2551,7 +2551,7 @@ HloInstructionProto HloScatterInstruction::ToProto() const {
 HloInstructionProto proto = HloInstruction::ToProto();
 *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
 proto.set_indices_are_sorted(indices_are_sorted());
-proto.set_use_atomic(use_atomic());
+proto.set_unique_indices(unique_indices());
 return proto;
 }

@@ -2562,8 +2562,8 @@ std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
 if (indices_are_sorted()) {
 attrs.push_back("indices_are_sorted=true");
 }
-if (!use_atomic()) {
-attrs.push_back("use_atomic=false");
+if (unique_indices()) {
+attrs.push_back("unique_indices=true");
 }
 return attrs;
 }

@@ -2578,7 +2578,7 @@ bool HloScatterInstruction::IdenticalSlowPath(
 casted_other.scatter_dimension_numbers()) &&
 eq_computations(to_apply(), casted_other.to_apply()) &&
 indices_are_sorted() == casted_other.indices_are_sorted() &&
-use_atomic() == casted_other.use_atomic();
+unique_indices() == casted_other.unique_indices();
 }

 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(

@@ -2587,7 +2587,7 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
 CHECK_EQ(new_operands.size(), 3);
 return absl::make_unique<HloScatterInstruction>(
 shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
-scatter_dimension_numbers(), indices_are_sorted(), use_atomic());
+scatter_dimension_numbers(), indices_are_sorted(), unique_indices());
 }

 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
@@ -1453,7 +1453,7 @@ class HloScatterInstruction : public HloInstruction {
 HloInstruction* scatter_indices, HloInstruction* updates,
 HloComputation* update_computation,
 const ScatterDimensionNumbers& scatter_dim_numbers,
-bool indices_are_sorted, bool use_atomic);
+bool indices_are_sorted, bool unique_indices);
 const ScatterDimensionNumbers& scatter_dimension_numbers() const {
 CHECK(scatter_dimension_numbers_ != nullptr);
 return *scatter_dimension_numbers_;

@@ -1462,7 +1462,7 @@ class HloScatterInstruction : public HloInstruction {
 void set_indices_are_sorted(bool indices_are_sorted) {
 indices_are_sorted_ = indices_are_sorted;
 }
-bool use_atomic() const override { return use_atomic_; }
+bool unique_indices() const override { return unique_indices_; }
 // Returns a serialized representation of this instruction.
 HloInstructionProto ToProto() const override;

@@ -1490,7 +1490,7 @@ class HloScatterInstruction : public HloInstruction {

 std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
 bool indices_are_sorted_;
-bool use_atomic_;
+bool unique_indices_;
 };

 class HloIotaInstruction : public HloInstruction {
@@ -1726,9 +1726,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
 optional<bool> indices_are_sorted = false;
 attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
 &indices_are_sorted};
-optional<bool> use_atomic = true;
-attrs["use_atomic"] = {/*required=*/false, AttrTy::kBool,
-&use_atomic};
+optional<bool> unique_indices = false;
+attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool,
+&unique_indices};

 if (!ParseOperands(&operands, /*expected_size=*/3) ||
 !ParseAttributes(attrs)) {

@@ -1745,7 +1745,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
 instruction = builder->AddInstruction(HloInstruction::CreateScatter(
 shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
 /*updates=*/operands[2], *update_computation, dim_numbers,
-indices_are_sorted.value(), use_atomic.value()));
+indices_are_sorted.value(), unique_indices.value()));
 break;
 }
 case HloOpcode::kDomain: {
@@ -937,8 +937,8 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7
 )"
 },
 {
-"AtomicScatter",
-R"(HloModule StringifyAtomicScatter
+"UniqueIndicesScatter",
+R"(HloModule StringifyUniqueIndicesScatter

 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
 %lhs = f32[] parameter(0)

@@ -950,7 +950,7 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7
 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
-ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, use_atomic=false, to_apply=%add_F32.v3
+ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
 }

 )"
@@ -225,7 +225,7 @@ ENTRY main {
 RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }

-XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_NoAtomic) {
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Add_UniqueIndices) {
 const string hlo_text = R"(
 HloModule TensorFlowScatter_Add

@@ -245,7 +245,7 @@ ENTRY main {
 inserted_window_dims={0},
 scatter_dims_to_operand_dims={0},
 index_vector_dim=1,
-use_atomic=false
+unique_indices=true
 }
 )";
 Literal operand =