[XLA] Add an option for Scatter to skip atomic operations. The default (use_atomic=true) preserves backward compatibility with the behavior from before this option existed.
commit 27903a3978 (parent c4fc64c728)
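For orientation before the diff, here is a minimal sketch of how a client could use the new parameter once this change is in, assuming the post-commit XlaBuilder API; the shapes, computation, and dimension numbers below are illustrative only and are not taken from this commit. Leaving use_atomic at its default of true keeps the old behavior.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: builds a scatter-add over unique indices and opts out of atomics.
xla::XlaOp ScatterAddWithoutAtomics(xla::XlaBuilder* builder) {
  xla::XlaOp operand = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {8}), "operand");
  xla::XlaOp indices = xla::Parameter(
      builder, 1, xla::ShapeUtil::MakeShape(xla::S32, {3, 1}), "indices");
  xla::XlaOp updates = xla::Parameter(
      builder, 2, xla::ShapeUtil::MakeShape(xla::F32, {3}), "updates");

  // Scalar update computation: new = old + update.
  xla::XlaBuilder add_builder("add");
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  xla::Add(xla::Parameter(&add_builder, 0, scalar, "a"),
           xla::Parameter(&add_builder, 1, scalar, "b"));
  xla::XlaComputation add = add_builder.Build().ValueOrDie();

  xla::ScatterDimensionNumbers dnums;
  dnums.add_inserted_window_dims(0);
  dnums.add_scatter_dims_to_operand_dims(0);
  dnums.set_index_vector_dim(1);

  // use_atomic=false is only safe because the indices are known to be unique.
  return xla::Scatter(operand, indices, updates, add, dnums,
                      /*indices_are_sorted=*/false, /*use_atomic=*/false);
}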
@@ -1734,11 +1734,13 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                           const XlaOp& updates,
                           const XlaComputation& update_computation,
                           const ScatterDimensionNumbers& dimension_numbers,
-                          bool indices_are_sorted) {
+                          bool indices_are_sorted, bool use_atomic) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     instr.set_indices_are_sorted(indices_are_sorted);
+
+    instr.set_use_atomic(use_atomic);
 
     TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
     TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
                         GetShape(scatter_indices));
@@ -3378,10 +3380,10 @@ XlaOp Gather(const XlaOp input, const XlaOp start_indices,
 XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
               const XlaOp updates, const XlaComputation& update_computation,
               const ScatterDimensionNumbers& dimension_numbers,
-              bool indices_are_sorted) {
+              bool indices_are_sorted, bool use_atomic) {
   return input.builder()->Scatter(input, scatter_indices, updates,
                                   update_computation, dimension_numbers,
-                                  indices_are_sorted);
+                                  indices_are_sorted, use_atomic);
 }
 
 void Send(const XlaOp operand, const ChannelHandle& handle) {
@@ -592,7 +592,7 @@ class XlaBuilder {
   XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
                 const XlaOp& updates, const XlaComputation& update_computation,
                 const ScatterDimensionNumbers& dimension_numbers,
-                bool indices_are_sorted = false);
+                bool indices_are_sorted = false, bool use_atomic = true);
 
   void Send(const XlaOp& operand, const ChannelHandle& handle);
   XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
@@ -1010,7 +1010,7 @@ class XlaBuilder {
   friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                        const XlaComputation& update_computation,
                        const ScatterDimensionNumbers& dimension_numbers,
-                       bool indices_are_sorted);
+                       bool indices_are_sorted, bool use_atomic);
   friend void Send(XlaOp operand, const ChannelHandle& handle);
   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
                     const ChannelHandle& handle);
@@ -1869,7 +1869,7 @@ XlaOp Gather(XlaOp input, XlaOp start_indices,
 XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
               const XlaComputation& update_computation,
               const ScatterDimensionNumbers& dimension_numbers,
-              bool indices_are_sorted = false);
+              bool indices_are_sorted = false, bool use_atomic = true);
 
 // Enqueues a Send node onto the computation for device-to-device
 // communication. This operation sends the given operand to
@@ -1544,11 +1544,12 @@ class ComputationBuilder(object):
               updates,
               update_computation,
               dimension_numbers,
-              indices_are_sorted=False):
+              indices_are_sorted=False,
+              use_atomic=True):
     """Enqueues a Scatter operation onto the computation."""
     return ops.Scatter(a, scatter_indices, updates,
                        update_computation.computation, dimension_numbers,
-                       indices_are_sorted)
+                       indices_are_sorted, use_atomic)
 
   def Fft(self, operand, fft_type, fft_lengths):
     """Enqueues a FFT operation onto the computation."""
@@ -403,7 +403,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
         /*scatter_indices_gen=*/
         scatter_fused_emitter.GetGenerator(root->operand(1)),
         /*updates_gen=*/
-        scatter_fused_emitter.GetGenerator(root->operand(2))));
+        scatter_fused_emitter.GetGenerator(root->operand(2)),
+        root->use_atomic()));
   }
   AddThunkToThunkSequence(
       absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
@@ -840,7 +841,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
       [=](const IrArray::Index& index) {
         return GetIrArray(*updates, *scatter)
             .EmitReadArrayElement(index, &b_, "update");
-      }));
+      },
+      scatter->use_atomic()));
 
   // Elide the sequential thunk if there's no copy.
   if (thunks.size() == 1) {
@@ -856,7 +858,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
 Status IrEmitterUnnested::EmitScatter(
     Thunk* thunk, HloInstruction* scatter,
     const llvm_ir::ElementGenerator& scatter_indices_gen,
-    const llvm_ir::ElementGenerator& updates_gen) {
+    const llvm_ir::ElementGenerator& updates_gen,
+    bool use_atomic) {
   const HloInstruction* operand = scatter->operand(0);
   const HloInstruction* scatter_indices = scatter->operand(1);
   const HloInstruction* updates = scatter->operand(2);
@@ -965,8 +968,14 @@ Status IrEmitterUnnested::EmitScatter(
         updates->shape().element_type(), module_));
     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
     Store(input_ir_value, input_address);
-    return EmitAtomicOperationForNestedComputation(
-        *scatter->to_apply(), output_address, input_address);
+
+    if (scatter->use_atomic()) {
+      return EmitAtomicOperationForNestedComputation(
+          *scatter->to_apply(), output_address, input_address);
+    } else {
+      return EmitCallToNestedComputation(
+          *scatter->to_apply(), {output_address, input_address}, output_address);
+    }
   };
 
   // Launch a kernel that reads every element in the updates tensor. We could
@@ -226,10 +226,14 @@ class IrEmitterUnnested : public IrEmitter,
   // Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
   // the process. `scatter` may be fused, scatter indices are taken from
   // `scatter_indices_gen`, updates from `updates_gen`. The output buffer is
-  // expected to have the operand values in it already.
+  // expected to have the operand values in it already. If use_atomic
+  // is true, we will use an atomic update. Passing false for use_atomic
+  // is safe only when it is guaranteed that there are no duplicate
+  // indices.
   Status EmitScatter(Thunk* thunk, HloInstruction* scatter,
                      const llvm_ir::ElementGenerator& scatter_indices_gen,
-                     const llvm_ir::ElementGenerator& updates_gen);
+                     const llvm_ir::ElementGenerator& updates_gen,
+                     bool use_atomic);
 
   // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
   // for the hlo instruction.
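The comment added above carries the key constraint: when two updates hit the same output element, the read-modify-write must be atomic or one of them can be lost. A standalone, conceptual C++ sketch of that hazard follows; it is not XLA emitter code, just an illustration of why use_atomic=false requires unique indices.

#include <array>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Parallel scatter-add where index 3 appears twice. With atomics, out[3]
// always ends up as 2 + 3 = 5; a plain non-atomic "out[i] += u" could lose
// one of the two colliding updates.
int main() {
  const std::vector<int> indices = {0, 3, 3, 7};
  const std::vector<int> updates = {1, 2, 3, 4};
  std::array<std::atomic<int>, 8> out{};

  std::vector<std::thread> workers;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    workers.emplace_back([&, i] {
      out[indices[i]].fetch_add(updates[i]);  // atomic read-modify-write
    });
  }
  for (std::thread& t : workers) t.join();
  return out[3].load() == 5 ? 0 : 1;  // exit code 0 whenever atomics are used
}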
@@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 69
+// Next ID: 70
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -237,6 +237,11 @@ message HloInstructionProto {
 
   // Frontend attributes to pass to the XLA backend.
   xla.FrontendAttributes frontend_attributes = 68;
+
+  // Specifies whether the scatter should use an atomic operation. If
+  // there are duplicate indices, this must be true to compute the
+  // right answer.
+  bool use_atomic = 69;
 }
 
 // Serialization of HloComputation.
@@ -565,7 +565,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
           proto.scatter_dimension_numbers());
       instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
                                   computations(0), *scatter_dimension_numbers,
-                                  proto.indices_are_sorted());
+                                  proto.indices_are_sorted(), proto.use_atomic());
       break;
     }
     case HloOpcode::kIota:
@@ -1393,10 +1393,10 @@ bool HloInstruction::HasSideEffect() const {
     HloInstruction* scatter_indices, HloInstruction* updates,
     HloComputation* update_computation,
     const ScatterDimensionNumbers& scatter_dim_numbers,
-    bool indices_are_sorted) {
+    bool indices_are_sorted, bool use_atomic) {
   return absl::make_unique<HloScatterInstruction>(
       shape, operand, scatter_indices, updates, update_computation,
-      scatter_dim_numbers, indices_are_sorted);
+      scatter_dim_numbers, indices_are_sorted, use_atomic);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
@@ -801,7 +801,7 @@ class HloInstruction {
       HloInstruction* scatter_indices, HloInstruction* updates,
       HloComputation* update_computation,
       const ScatterDimensionNumbers& scatter_dim_numbers,
-      bool indices_are_sorted);
+      bool indices_are_sorted, bool use_atomic);
 
   // Creates a kDomain instruction which delimits an HLO domain which have
   // the provided user and operand side metadata.
@@ -1629,6 +1629,11 @@ class HloInstruction {
     LOG(FATAL) << "Unimplemented method.";
   }
 
+  // Returns the use_atomic field.
+  virtual bool use_atomic() const {
+    LOG(FATAL) << "Unimplemented method.";
+  }
+
   // Returns data on the dimension numbers used for a convolution operation,
   // which may be a kConvolution instruction or a kCustomCall that implements a
   // convolution.
@@ -2493,9 +2493,10 @@ HloScatterInstruction::HloScatterInstruction(
     const Shape& shape, HloInstruction* operand,
     HloInstruction* scatter_indices, HloInstruction* updates,
     HloComputation* update_computation,
-    const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted)
+    const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
+    bool use_atomic)
     : HloInstruction(HloOpcode::kScatter, shape),
-      indices_are_sorted_(indices_are_sorted) {
+      indices_are_sorted_(indices_are_sorted), use_atomic_(use_atomic) {
   AppendOperand(operand);
   AppendOperand(scatter_indices);
   AppendOperand(updates);
@@ -2550,6 +2551,7 @@ HloInstructionProto HloScatterInstruction::ToProto() const {
   HloInstructionProto proto = HloInstruction::ToProto();
   *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
   proto.set_indices_are_sorted(indices_are_sorted());
+  proto.set_use_atomic(use_atomic());
   return proto;
 }
 
@@ -2560,6 +2562,9 @@ std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
   if (indices_are_sorted()) {
     attrs.push_back("indices_are_sorted=true");
   }
+  if (!use_atomic()) {
+    attrs.push_back("use_atomic=false");
+  }
   return attrs;
 }
 
@@ -2572,7 +2577,8 @@ bool HloScatterInstruction::IdenticalSlowPath(
              scatter_dimension_numbers(),
              casted_other.scatter_dimension_numbers()) &&
          eq_computations(to_apply(), casted_other.to_apply()) &&
-         indices_are_sorted() == casted_other.indices_are_sorted();
+         indices_are_sorted() == casted_other.indices_are_sorted() &&
+         use_atomic() == casted_other.use_atomic();
 }
 
 std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
@@ -2581,7 +2587,7 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
   CHECK_EQ(new_operands.size(), 3);
   return absl::make_unique<HloScatterInstruction>(
       shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
-      scatter_dimension_numbers(), indices_are_sorted());
+      scatter_dimension_numbers(), indices_are_sorted(), use_atomic());
 }
 
 HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension)
@@ -1453,7 +1453,7 @@ class HloScatterInstruction : public HloInstruction {
                         HloInstruction* scatter_indices, HloInstruction* updates,
                         HloComputation* update_computation,
                         const ScatterDimensionNumbers& scatter_dim_numbers,
-                        bool indices_are_sorted);
+                        bool indices_are_sorted, bool use_atomic);
   const ScatterDimensionNumbers& scatter_dimension_numbers() const {
     CHECK(scatter_dimension_numbers_ != nullptr);
     return *scatter_dimension_numbers_;
@@ -1462,6 +1462,7 @@ class HloScatterInstruction : public HloInstruction {
   void set_indices_are_sorted(bool indices_are_sorted) {
     indices_are_sorted_ = indices_are_sorted;
   }
+  bool use_atomic() const override { return use_atomic_; }
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
 
@@ -1489,6 +1490,7 @@ class HloScatterInstruction : public HloInstruction {
 
   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
   bool indices_are_sorted_;
+  bool use_atomic_;
 };
 
 class HloIotaInstruction : public HloInstruction {
@@ -1726,6 +1726,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<bool> indices_are_sorted = false;
       attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
                                      &indices_are_sorted};
+      optional<bool> use_atomic = true;
+      attrs["use_atomic"] = {/*required=*/false, AttrTy::kBool,
+                             &use_atomic};
 
       if (!ParseOperands(&operands, /*expected_size=*/3) ||
           !ParseAttributes(attrs)) {
@@ -1742,7 +1745,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
       instruction = builder->AddInstruction(HloInstruction::CreateScatter(
          shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
          /*updates=*/operands[2], *update_computation, dim_numbers,
-         indices_are_sorted.value()));
+         indices_are_sorted.value(), use_atomic.value()));
       break;
     }
     case HloOpcode::kDomain: {
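With the parser and printer changes above, the flag round-trips through HLO text as an optional attribute that defaults to true and is printed only when it is false. A hypothetical scatter instruction in HLO text form might then read as below; the names, shapes, and dimension numbers are invented for illustration and are not taken from this commit.

  %scatter = f32[8]{0} scatter(f32[8]{0} %operand, s32[3,1]{1,0} %indices, f32[3]{0} %updates),
      update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0},
      index_vector_dim=1, to_apply=%add_computation, use_atomic=false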