[XLA] Allow CustomCall to specify aliasing buffers
PiperOrigin-RevId: 333464460 Change-Id: I666c7a7fade6c9925cd8d2c7dd4673b897a38034
This commit is contained in:
parent
55b546de0e
commit
1b5b9486b0
tensorflow/compiler
@ -135,16 +135,12 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
bool has_side_effect) {
|
||||
if (operand_shapes_with_layout.has_value())
|
||||
return Unimplemented(
|
||||
"CustomCall doesn't support operands shapes with layout");
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
TF_RET_CHECK(output_operand_aliasing.empty())
|
||||
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
|
||||
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
|
||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),
|
||||
|
@ -135,9 +135,7 @@ class MlirHloBuilder : public XlaBuilder {
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) override;
|
||||
bool has_side_effect) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
|
@ -1707,9 +1707,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
bool has_side_effect) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
return InvalidArgument(
|
||||
@ -1741,8 +1739,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
}
|
||||
}
|
||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout, has_side_effect,
|
||||
output_operand_aliasing);
|
||||
operand_shapes_with_layout, has_side_effect);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1750,9 +1747,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
bool has_side_effect) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
@ -1764,16 +1759,6 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
}
|
||||
}
|
||||
instr.set_custom_call_has_side_effect(has_side_effect);
|
||||
for (const auto& pair : output_operand_aliasing) {
|
||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
||||
aliasing->set_operand_index(pair.second.first);
|
||||
for (int64 index : pair.second.second) {
|
||||
aliasing->add_operand_shape_index(index);
|
||||
}
|
||||
for (int64 index : pair.first) {
|
||||
aliasing->add_output_shape_index(index);
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
}
|
||||
|
||||
@ -1781,9 +1766,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
bool has_side_effect) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
@ -1821,16 +1804,6 @@ XlaOp XlaBuilder::CustomCall(
|
||||
}
|
||||
}
|
||||
AddCalledComputation(computation, &instr);
|
||||
for (const auto& pair : output_operand_aliasing) {
|
||||
auto aliasing = instr.add_custom_call_output_operand_aliasing();
|
||||
aliasing->set_operand_index(pair.second.first);
|
||||
for (int64 index : pair.second.second) {
|
||||
aliasing->add_operand_shape_index(index);
|
||||
}
|
||||
for (int64 index : pair.first) {
|
||||
aliasing->add_output_shape_index(index);
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
});
|
||||
}
|
||||
@ -3888,39 +3861,31 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
||||
return builder->Call(computation, operands);
|
||||
}
|
||||
|
||||
XlaOp CustomCall(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque, bool has_side_effect) {
|
||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||
has_side_effect, output_operand_aliasing);
|
||||
has_side_effect);
|
||||
}
|
||||
|
||||
XlaOp CustomCallWithComputation(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
return builder->CustomCall(call_target_name, operands, computation, shape,
|
||||
opaque,
|
||||
/*operand_shapes_with_layout=*/absl::nullopt,
|
||||
has_side_effect, output_operand_aliasing);
|
||||
XlaOp CustomCallWithComputation(XlaBuilder* builder,
|
||||
const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque,
|
||||
bool has_side_effect) {
|
||||
return builder->CustomCall(
|
||||
call_target_name, operands, computation, shape, opaque,
|
||||
/*operand_shapes_with_layout=*/absl::nullopt, has_side_effect);
|
||||
}
|
||||
|
||||
XlaOp CustomCallWithLayout(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing) {
|
||||
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const string& opaque, bool has_side_effect) {
|
||||
return builder->CustomCall(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout, has_side_effect,
|
||||
output_operand_aliasing);
|
||||
operand_shapes_with_layout, has_side_effect);
|
||||
}
|
||||
|
||||
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
|
||||
|
@ -593,9 +593,7 @@ class XlaBuilder {
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
bool has_side_effect);
|
||||
|
||||
// Internal version of CustomCall without computation that doesn't do op
|
||||
// specific error handling and expects arguments to be legal. CustomCall
|
||||
@ -604,18 +602,14 @@ class XlaBuilder {
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
bool has_side_effect);
|
||||
|
||||
XlaOp CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape_with_layout,
|
||||
const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
bool has_side_effect);
|
||||
|
||||
XlaOp Reduce(XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation,
|
||||
@ -1064,25 +1058,18 @@ class XlaBuilder {
|
||||
const string& outfeed_config);
|
||||
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
||||
absl::Span<const XlaOp> operands);
|
||||
friend XlaOp CustomCall(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque, bool has_side_effect);
|
||||
friend XlaOp CustomCallWithComputation(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque, bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
const Shape& shape, const string& opaque, bool has_side_effect);
|
||||
friend XlaOp CustomCallWithLayout(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
|
||||
bool has_side_effect,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing);
|
||||
bool has_side_effect);
|
||||
friend XlaOp Complex(XlaOp real, XlaOp imag,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Conj(XlaOp operand);
|
||||
@ -1818,39 +1805,30 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
|
||||
// backend, a call instruction is emitted which targets a symbol with the name
|
||||
// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
|
||||
// but |call_target_name| should be short as it may be used in labels. |opaque|
|
||||
// can encode arbitrarily large amounts of information. |has_side_effect|
|
||||
// specifies whether the instruction can have side effects.
|
||||
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
|
||||
// that alias each other, where the output buffer is represented as a
|
||||
// ShapeIndex, and the operand buffer is represented as the operand index and
|
||||
// the ShapeIndex.
|
||||
XlaOp CustomCall(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
// can encode arbitrarily large amounts of information.
|
||||
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape,
|
||||
const string& opaque = "", bool has_side_effect = false);
|
||||
|
||||
// Overload which constructs a custom call that applies an Xla computation.
|
||||
XlaOp CustomCallWithComputation(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
XlaOp CustomCallWithComputation(XlaBuilder* builder,
|
||||
const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation,
|
||||
const Shape& shape, const string& opaque = "",
|
||||
bool has_side_effect = false);
|
||||
|
||||
// Overload which constructs a custom call with fixed layouts. The operands will
|
||||
// have the layouts specified by |operand_shapes_with_layout| when provided to
|
||||
// external code, and the external code is expected to produce a result with the
|
||||
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
|
||||
// and |operand_shapes_with_layout| must have layouts.
|
||||
XlaOp CustomCallWithLayout(
|
||||
XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const string& opaque = "", bool has_side_effect = false,
|
||||
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_operand_aliasing = {});
|
||||
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands,
|
||||
const Shape& shape_with_layout,
|
||||
absl::Span<const Shape> operand_shapes_with_layout,
|
||||
const string& opaque = "",
|
||||
bool has_side_effect = false);
|
||||
|
||||
// The following methods enqueue element-wise binary arithmetic operations
|
||||
// onto the computation. The shapes of the operands have to match unless one
|
||||
|
@ -3492,8 +3492,6 @@ cc_library(
|
||||
hdrs = ["memory_space_assignment_utils.h"],
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Serialization of HloInstruction.
|
||||
// Next ID: 75
|
||||
// Next ID: 74
|
||||
message HloInstructionProto {
|
||||
reserved 10;
|
||||
reserved "parameter_name";
|
||||
@ -232,11 +232,6 @@ message HloInstructionProto {
|
||||
// kCustomCall.
|
||||
bool custom_call_has_side_effect = 65;
|
||||
|
||||
// A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
|
||||
// buffers between output and operands for kCustomCall.
|
||||
repeated xla.CustomCallOutputOperandAliasing
|
||||
custom_call_output_operand_aliasing = 74;
|
||||
|
||||
// The delta value for kRngGetAndUpdateState.
|
||||
int64 delta = 66;
|
||||
|
||||
|
@ -432,23 +432,6 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCustomCallValueSet(
|
||||
HloInstruction* custom_call) {
|
||||
CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
|
||||
bool changed = false;
|
||||
for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
|
||||
->output_to_operand_aliasing()) {
|
||||
const HloValueSet& operand_value_set = GetValueSet(
|
||||
custom_call->operand(aliasing.second.first), aliasing.second.second);
|
||||
HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
|
||||
if (value_set != operand_value_set) {
|
||||
value_set = operand_value_set;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
|
||||
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
|
||||
bool changed = false;
|
||||
@ -774,8 +757,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||
return UpdateAddDependencyValueSet(instruction);
|
||||
case HloOpcode::kBitcast:
|
||||
return UpdateBitcastValueSet(instruction);
|
||||
case HloOpcode::kCustomCall:
|
||||
return UpdateCustomCallValueSet(instruction);
|
||||
case HloOpcode::kSetDimensionSize:
|
||||
return UpdateSetDimensionSizeValueSet(instruction);
|
||||
case HloOpcode::kDomain:
|
||||
@ -1037,22 +1018,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
define_value_at(/*index=*/{1});
|
||||
define_value_at(/*index=*/{2});
|
||||
break;
|
||||
case HloOpcode::kCustomCall: {
|
||||
absl::flat_hash_set<ShapeIndex> aliasing_indices;
|
||||
for (const auto& aliasing :
|
||||
Cast<HloCustomCallInstruction>(instruction)
|
||||
->output_to_operand_aliasing()) {
|
||||
aliasing_indices.insert(aliasing.first);
|
||||
}
|
||||
ShapeUtil::ForEachSubshape(
|
||||
instruction->shape(),
|
||||
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
|
||||
if (!aliasing_indices.contains(index)) {
|
||||
define_value_at(index);
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
define_all_values();
|
||||
break;
|
||||
|
@ -216,7 +216,6 @@ class HloDataflowAnalysis {
|
||||
bool UpdateCallValueSet(HloInstruction* call);
|
||||
bool UpdateConditionalValueSet(HloInstruction* conditional);
|
||||
bool UpdateCopyValueSet(HloInstruction* copy);
|
||||
bool UpdateCustomCallValueSet(HloInstruction* custom_call);
|
||||
bool UpdateDomainValueSet(HloInstruction* domain);
|
||||
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
|
||||
bool UpdateParameterValueSet(HloInstruction* parameter);
|
||||
|
@ -568,19 +568,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
|
||||
custom_call_instr->set_custom_call_has_side_effect(
|
||||
proto.custom_call_has_side_effect());
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_to_operand_aliasing;
|
||||
for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
|
||||
output_to_operand_aliasing.emplace_back(
|
||||
ShapeIndex(aliasing.output_shape_index().begin(),
|
||||
aliasing.output_shape_index().end()),
|
||||
std::pair<int64, ShapeIndex>{
|
||||
aliasing.operand_index(),
|
||||
ShapeIndex(aliasing.operand_shape_index().begin(),
|
||||
aliasing.operand_shape_index().end())});
|
||||
}
|
||||
custom_call_instr->set_output_to_operand_aliasing(
|
||||
std::move(output_to_operand_aliasing));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kPad:
|
||||
|
@ -2395,16 +2395,6 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
||||
}
|
||||
}
|
||||
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||
for (const auto& pair : output_to_operand_aliasing_) {
|
||||
auto aliasing = proto.add_custom_call_output_operand_aliasing();
|
||||
aliasing->set_operand_index(pair.second.first);
|
||||
for (int64 index : pair.first) {
|
||||
aliasing->add_output_shape_index(index);
|
||||
}
|
||||
for (int64 index : pair.second.second) {
|
||||
aliasing->add_operand_shape_index(index);
|
||||
}
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -2442,16 +2432,6 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
|
||||
if (custom_call_has_side_effect_) {
|
||||
extra.push_back("custom_call_has_side_effect=true");
|
||||
}
|
||||
if (!output_to_operand_aliasing_.empty()) {
|
||||
std::vector<string> pair_strings;
|
||||
for (const auto& pair : output_to_operand_aliasing_) {
|
||||
pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
|
||||
pair.second.first, ", ",
|
||||
pair.second.second.ToString(), ")"));
|
||||
}
|
||||
extra.push_back(StrCat("output_to_operand_aliasing={",
|
||||
StrJoin(pair_strings, ", "), "}"));
|
||||
}
|
||||
return extra;
|
||||
}
|
||||
|
||||
@ -2495,10 +2475,6 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
||||
casted_other.custom_call_has_side_effect()) {
|
||||
return false;
|
||||
}
|
||||
if (output_to_operand_aliasing_ !=
|
||||
casted_other.output_to_operand_aliasing()) {
|
||||
return false;
|
||||
}
|
||||
// Note: backend_config comparison is done in Identical, which is the
|
||||
// intended/exposed way to compare computations, and so not repeated here.
|
||||
return custom_call_target_ == casted_other.custom_call_target_;
|
||||
@ -2523,7 +2499,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
|
||||
cloned->set_feature_group_count(feature_group_count_);
|
||||
cloned->set_batch_group_count(batch_group_count_);
|
||||
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
|
||||
cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
|
||||
return std::move(cloned);
|
||||
}
|
||||
|
||||
|
@ -1430,20 +1430,6 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
CHECK(layout_constrained());
|
||||
return operand_shapes_with_layout_;
|
||||
}
|
||||
// Gets a list of output/operand buffer pairs that alias each other, where the
|
||||
// output buffer is represented as a ShapeIndex, and the operand buffer is
|
||||
// represented as the operand index and the ShapeIndex. By default this list
|
||||
// is empty.
|
||||
const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
|
||||
output_to_operand_aliasing() const {
|
||||
return output_to_operand_aliasing_;
|
||||
}
|
||||
// Sets the list of output/operand buffer pairs that alias each other.
|
||||
void set_output_to_operand_aliasing(
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
aliasing) {
|
||||
output_to_operand_aliasing_ = std::move(aliasing);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
@ -1472,10 +1458,6 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
std::vector<Shape> operand_shapes_with_layout_;
|
||||
// Whether this custom call has a side-effect.
|
||||
bool custom_call_has_side_effect_;
|
||||
// A list of output/operand buffer pairs that alias each other. See comment of
|
||||
// output_to_operand_aliasing().
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
output_to_operand_aliasing_;
|
||||
};
|
||||
|
||||
class HloPadInstruction : public HloInstruction {
|
||||
|
@ -212,7 +212,6 @@ class HloParserImpl : public HloParser {
|
||||
kEnum,
|
||||
kRandomAlgorithm,
|
||||
kAliasing,
|
||||
kInstructionAliasing,
|
||||
};
|
||||
|
||||
struct AttrConfig {
|
||||
@ -347,12 +346,6 @@ class HloParserImpl : public HloParser {
|
||||
// fails.
|
||||
bool ParseAliasing(AliasingData* data);
|
||||
|
||||
// Parses the per-instruction aliasing information from string `s`, returns
|
||||
// `false` if it fails.
|
||||
bool ParseInstructionOutputOperandAliasing(
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
|
||||
aliasing_output_operand_pairs);
|
||||
|
||||
bool ParseShapeIndex(ShapeIndex* out);
|
||||
|
||||
// Returns true if the current token is the beginning of a shape.
|
||||
@ -605,58 +598,6 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParserImpl::ParseInstructionOutputOperandAliasing(
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
|
||||
aliasing_output_operand_pairs) {
|
||||
if (!ParseToken(
|
||||
TokKind::kLbrace,
|
||||
"Expects '{' at the start of instruction aliasing description")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
while (lexer_.GetKind() != TokKind::kRbrace) {
|
||||
ShapeIndex out;
|
||||
if (!ParseShapeIndex(&out)) {
|
||||
return false;
|
||||
}
|
||||
std::string errmsg =
|
||||
"Expected format: <output_shape_index>: (<operand_index>, "
|
||||
"<operand_shape_index>)";
|
||||
if (!ParseToken(TokKind::kColon, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
int64 operand_index;
|
||||
ParseInt64(&operand_index);
|
||||
if (!ParseToken(TokKind::kComma, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
ShapeIndex operand_shape_index;
|
||||
if (!ParseShapeIndex(&operand_shape_index)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
aliasing_output_operand_pairs->emplace_back(
|
||||
out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
|
||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!EatIfPresent(TokKind::kComma)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!ParseToken(
|
||||
TokKind::kRbrace,
|
||||
"Expects '}' at the end of instruction aliasing description")) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// ::= 'HloModule' name computations
|
||||
bool HloParserImpl::ParseHloModule(HloModule* module) {
|
||||
if (lexer_.GetKind() != TokKind::kw_HloModule) {
|
||||
@ -1836,8 +1777,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
optional<std::vector<Shape>> operand_layout_constraints;
|
||||
optional<bool> custom_call_has_side_effect;
|
||||
optional<HloComputation*> to_apply;
|
||||
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
|
||||
output_to_operand_aliasing;
|
||||
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
||||
&custom_call_target};
|
||||
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
||||
@ -1853,9 +1792,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
&custom_call_has_side_effect};
|
||||
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
||||
&to_apply};
|
||||
attrs["output_to_operand_aliasing"] = {/*required=*/false,
|
||||
AttrTy::kInstructionAliasing,
|
||||
&output_to_operand_aliasing};
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
@ -1925,10 +1861,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
custom_call_instr->set_custom_call_has_side_effect(
|
||||
*custom_call_has_side_effect);
|
||||
}
|
||||
if (output_to_operand_aliasing.has_value()) {
|
||||
custom_call_instr->set_output_to_operand_aliasing(
|
||||
std::move(*output_to_operand_aliasing));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kDot: {
|
||||
@ -3291,19 +3223,6 @@ bool HloParserImpl::ParseAttributeHelper(
|
||||
->emplace(aliasing_data);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kInstructionAliasing: {
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
|
||||
aliasing_output_operand_pairs;
|
||||
if (!ParseInstructionOutputOperandAliasing(
|
||||
&aliasing_output_operand_pairs)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<
|
||||
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
|
||||
attr_out_ptr)
|
||||
->emplace(std::move(aliasing_output_operand_pairs));
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}();
|
||||
if (!success) {
|
||||
|
@ -991,19 +991,6 @@ ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4])
|
||||
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// CustomCallWithAliasing
|
||||
{
|
||||
"CustomCallWithAliasing",
|
||||
R"(HloModule CustomCallWithAliasing
|
||||
|
||||
ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
|
||||
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
|
||||
%p1 = f32[123,4]{0,1} parameter(1)
|
||||
ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// Parse c64 literal
|
||||
|
@ -801,28 +801,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
|
||||
}
|
||||
}
|
||||
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
|
||||
TF_RET_CHECK(pair.second.first < custom_call->operand_count())
|
||||
<< "Invalid aliasing operand index.";
|
||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(
|
||||
custom_call->operand(pair.second.first)->shape(), pair.second.second))
|
||||
<< "Invalid aliasing operand shape index.";
|
||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
|
||||
<< "Invalid aliasing output shape index.";
|
||||
const Shape& output_subshape =
|
||||
ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
|
||||
const Shape& operand_subshape = ShapeUtil::GetSubshape(
|
||||
custom_call->operand(pair.second.first)->shape(), pair.second.second);
|
||||
if (layout_sensitive_) {
|
||||
TF_RET_CHECK(operand_subshape == output_subshape)
|
||||
<< "Different aliasing shapes: " << operand_subshape.ToString()
|
||||
<< " vs " << output_subshape.ToString();
|
||||
} else {
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
|
||||
<< "Different aliasing shapes: " << operand_subshape.ToString()
|
||||
<< " vs " << output_subshape.ToString();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -15,9 +15,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
|
||||
@ -90,17 +87,6 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (auto* custom_call =
|
||||
DynCast<HloCustomCallInstruction>(position.instruction)) {
|
||||
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
|
||||
if (position.index == pair.first) {
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a custom-call output that "
|
||||
"aliases an operand buffer.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -691,11 +691,3 @@ message WhileLoopBackendConfig {
|
||||
// unknown-trip-count.
|
||||
KnownTripCount known_trip_count = 1;
|
||||
}
|
||||
|
||||
// Specifies a pair of output/operand buffers for kCustomCall that alias each
|
||||
// other.
|
||||
message CustomCallOutputOperandAliasing {
|
||||
repeated int64 output_shape_index = 1;
|
||||
int64 operand_index = 2;
|
||||
repeated int64 operand_shape_index = 3;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user