[XLA] Allow CustomCall to specify aliasing buffers

PiperOrigin-RevId: 333464460
Change-Id: I666c7a7fade6c9925cd8d2c7dd4673b897a38034
This commit is contained in:
A. Unique TensorFlower 2020-09-24 01:10:46 -07:00 committed by TensorFlower Gardener
parent 55b546de0e
commit 1b5b9486b0
16 changed files with 49 additions and 349 deletions

View File

@ -135,16 +135,12 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
bool has_side_effect) {
if (operand_shapes_with_layout.has_value())
return Unimplemented(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
TF_RET_CHECK(output_operand_aliasing.empty())
<< "MLIR CustomCallOp does not support output_operand_aliasing yet";
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(has_side_effect),

View File

@ -135,9 +135,7 @@ class MlirHloBuilder : public XlaBuilder {
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) override;
bool has_side_effect) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,

View File

@ -1707,9 +1707,7 @@ XlaOp XlaBuilder::CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
bool has_side_effect) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (absl::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@ -1741,8 +1739,7 @@ XlaOp XlaBuilder::CustomCall(
}
}
return CustomCallInternal(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing);
operand_shapes_with_layout, has_side_effect);
});
}
@ -1750,9 +1747,7 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
bool has_side_effect) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
@ -1764,16 +1759,6 @@ StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
}
}
instr.set_custom_call_has_side_effect(has_side_effect);
for (const auto& pair : output_operand_aliasing) {
auto aliasing = instr.add_custom_call_output_operand_aliasing();
aliasing->set_operand_index(pair.second.first);
for (int64 index : pair.second.second) {
aliasing->add_operand_shape_index(index);
}
for (int64 index : pair.first) {
aliasing->add_output_shape_index(index);
}
}
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
@ -1781,9 +1766,7 @@ XlaOp XlaBuilder::CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
bool has_side_effect) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@ -1821,16 +1804,6 @@ XlaOp XlaBuilder::CustomCall(
}
}
AddCalledComputation(computation, &instr);
for (const auto& pair : output_operand_aliasing) {
auto aliasing = instr.add_custom_call_output_operand_aliasing();
aliasing->set_operand_index(pair.second.first);
for (int64 index : pair.second.second) {
aliasing->add_operand_shape_index(index);
}
for (int64 index : pair.first) {
aliasing->add_output_shape_index(index);
}
}
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@ -3888,39 +3861,31 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
return builder->Call(computation, operands);
}
XlaOp CustomCall(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque, bool has_side_effect) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing);
has_side_effect);
}
XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
return builder->CustomCall(call_target_name, operands, computation, shape,
opaque,
/*operand_shapes_with_layout=*/absl::nullopt,
has_side_effect, output_operand_aliasing);
XlaOp CustomCallWithComputation(XlaBuilder* builder,
const string& call_target_name,
absl::Span<const XlaOp> operands,
const XlaComputation& computation,
const Shape& shape, const string& opaque,
bool has_side_effect) {
return builder->CustomCall(
call_target_name, operands, computation, shape, opaque,
/*operand_shapes_with_layout=*/absl::nullopt, has_side_effect);
}
XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing) {
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
absl::Span<const Shape> operand_shapes_with_layout,
const string& opaque, bool has_side_effect) {
return builder->CustomCall(call_target_name, operands, shape, opaque,
operand_shapes_with_layout, has_side_effect,
output_operand_aliasing);
operand_shapes_with_layout, has_side_effect);
}
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,

View File

@ -593,9 +593,7 @@ class XlaBuilder {
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
bool has_side_effect);
// Internal version of CustomCall without computation that doesn't do op
// specific error handling and expects arguments to be legal. CustomCall
@ -604,18 +602,14 @@ class XlaBuilder {
const string& call_target_name, absl::Span<const XlaOp> operands,
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
bool has_side_effect);
XlaOp CustomCall(
const string& call_target_name, absl::Span<const XlaOp> operands,
const XlaComputation& computation, const Shape& shape_with_layout,
const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
bool has_side_effect);
XlaOp Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
@ -1064,25 +1058,18 @@ class XlaBuilder {
const string& outfeed_config);
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque, bool has_side_effect);
friend XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque, bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
const Shape& shape, const string& opaque, bool has_side_effect);
friend XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
bool has_side_effect,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing);
bool has_side_effect);
friend XlaOp Complex(XlaOp real, XlaOp imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(XlaOp operand);
@ -1818,39 +1805,30 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
// can encode arbitrarily large amounts of information. |has_side_effect|
// specifies whether the instruction can have side effects.
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
// that alias each other, where the output buffer is represented as a
// ShapeIndex, and the operand buffer is represented as the operand index and
// the ShapeIndex.
XlaOp CustomCall(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "", bool has_side_effect = false);
// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const XlaComputation& computation,
const Shape& shape, const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
XlaOp CustomCallWithComputation(XlaBuilder* builder,
const string& call_target_name,
absl::Span<const XlaOp> operands,
const XlaComputation& computation,
const Shape& shape, const string& opaque = "",
bool has_side_effect = false);
// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(
XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout,
const string& opaque = "", bool has_side_effect = false,
absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_operand_aliasing = {});
XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands,
const Shape& shape_with_layout,
absl::Span<const Shape> operand_shapes_with_layout,
const string& opaque = "",
bool has_side_effect = false);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one

View File

@ -3492,8 +3492,6 @@ cc_library(
hdrs = ["memory_space_assignment_utils.h"],
deps = [
":heap_simulator",
":hlo",
":hlo_casting_utils",
],
)

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 75
// Next ID: 74
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -232,11 +232,6 @@ message HloInstructionProto {
// kCustomCall.
bool custom_call_has_side_effect = 65;
// A list of CustomCallOutputOperandAliasing pairs that specifies aliasing
// buffers between output and operands for kCustomCall.
repeated xla.CustomCallOutputOperandAliasing
custom_call_output_operand_aliasing = 74;
// The delta value for kRngGetAndUpdateState.
int64 delta = 66;

View File

@ -432,23 +432,6 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
return changed;
}
bool HloDataflowAnalysis::UpdateCustomCallValueSet(
HloInstruction* custom_call) {
CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
bool changed = false;
for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
->output_to_operand_aliasing()) {
const HloValueSet& operand_value_set = GetValueSet(
custom_call->operand(aliasing.second.first), aliasing.second.second);
HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
bool changed = false;
@ -774,8 +757,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateAddDependencyValueSet(instruction);
case HloOpcode::kBitcast:
return UpdateBitcastValueSet(instruction);
case HloOpcode::kCustomCall:
return UpdateCustomCallValueSet(instruction);
case HloOpcode::kSetDimensionSize:
return UpdateSetDimensionSizeValueSet(instruction);
case HloOpcode::kDomain:
@ -1037,22 +1018,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_value_at(/*index=*/{1});
define_value_at(/*index=*/{2});
break;
case HloOpcode::kCustomCall: {
absl::flat_hash_set<ShapeIndex> aliasing_indices;
for (const auto& aliasing :
Cast<HloCustomCallInstruction>(instruction)
->output_to_operand_aliasing()) {
aliasing_indices.insert(aliasing.first);
}
ShapeUtil::ForEachSubshape(
instruction->shape(),
[&](const Shape& /*subshape*/, const ShapeIndex& index) {
if (!aliasing_indices.contains(index)) {
define_value_at(index);
}
});
break;
}
default:
define_all_values();
break;

View File

@ -216,7 +216,6 @@ class HloDataflowAnalysis {
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
bool UpdateCustomCallValueSet(HloInstruction* custom_call);
bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);

View File

@ -568,19 +568,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
custom_call_instr->set_custom_call_has_side_effect(
proto.custom_call_has_side_effect());
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_to_operand_aliasing;
for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
output_to_operand_aliasing.emplace_back(
ShapeIndex(aliasing.output_shape_index().begin(),
aliasing.output_shape_index().end()),
std::pair<int64, ShapeIndex>{
aliasing.operand_index(),
ShapeIndex(aliasing.operand_shape_index().begin(),
aliasing.operand_shape_index().end())});
}
custom_call_instr->set_output_to_operand_aliasing(
std::move(output_to_operand_aliasing));
break;
}
case HloOpcode::kPad:

View File

@ -2395,16 +2395,6 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
}
}
proto.set_custom_call_has_side_effect(custom_call_has_side_effect_);
for (const auto& pair : output_to_operand_aliasing_) {
auto aliasing = proto.add_custom_call_output_operand_aliasing();
aliasing->set_operand_index(pair.second.first);
for (int64 index : pair.first) {
aliasing->add_output_shape_index(index);
}
for (int64 index : pair.second.second) {
aliasing->add_operand_shape_index(index);
}
}
return proto;
}
@ -2442,16 +2432,6 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (custom_call_has_side_effect_) {
extra.push_back("custom_call_has_side_effect=true");
}
if (!output_to_operand_aliasing_.empty()) {
std::vector<string> pair_strings;
for (const auto& pair : output_to_operand_aliasing_) {
pair_strings.push_back(StrCat(pair.first.ToString(), ": (",
pair.second.first, ", ",
pair.second.second.ToString(), ")"));
}
extra.push_back(StrCat("output_to_operand_aliasing={",
StrJoin(pair_strings, ", "), "}"));
}
return extra;
}
@ -2495,10 +2475,6 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.custom_call_has_side_effect()) {
return false;
}
if (output_to_operand_aliasing_ !=
casted_other.output_to_operand_aliasing()) {
return false;
}
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
return custom_call_target_ == casted_other.custom_call_target_;
@ -2523,7 +2499,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
cloned->set_feature_group_count(feature_group_count_);
cloned->set_batch_group_count(batch_group_count_);
cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_);
cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_);
return std::move(cloned);
}

View File

@ -1430,20 +1430,6 @@ class HloCustomCallInstruction : public HloInstruction {
CHECK(layout_constrained());
return operand_shapes_with_layout_;
}
// Gets a list of output/operand buffer pairs that alias each other, where the
// output buffer is represented as a ShapeIndex, and the operand buffer is
// represented as the operand index and the ShapeIndex. By default this list
// is empty.
const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
output_to_operand_aliasing() const {
return output_to_operand_aliasing_;
}
// Sets the list of output/operand buffer pairs that alias each other.
void set_output_to_operand_aliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
aliasing) {
output_to_operand_aliasing_ = std::move(aliasing);
}
private:
std::vector<string> ExtraAttributesToStringImpl(
@ -1472,10 +1458,6 @@ class HloCustomCallInstruction : public HloInstruction {
std::vector<Shape> operand_shapes_with_layout_;
// Whether this custom call has a side-effect.
bool custom_call_has_side_effect_;
// A list of output/operand buffer pairs that alias each other. See comment of
// output_to_operand_aliasing().
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
output_to_operand_aliasing_;
};
class HloPadInstruction : public HloInstruction {

View File

@ -212,7 +212,6 @@ class HloParserImpl : public HloParser {
kEnum,
kRandomAlgorithm,
kAliasing,
kInstructionAliasing,
};
struct AttrConfig {
@ -347,12 +346,6 @@ class HloParserImpl : public HloParser {
// fails.
bool ParseAliasing(AliasingData* data);
// Parses the per-instruction aliasing information from string `s`, returns
// `false` if it fails.
bool ParseInstructionOutputOperandAliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
aliasing_output_operand_pairs);
bool ParseShapeIndex(ShapeIndex* out);
// Returns true if the current token is the beginning of a shape.
@ -605,58 +598,6 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
return true;
}
bool HloParserImpl::ParseInstructionOutputOperandAliasing(
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>*
aliasing_output_operand_pairs) {
if (!ParseToken(
TokKind::kLbrace,
"Expects '{' at the start of instruction aliasing description")) {
return false;
}
while (lexer_.GetKind() != TokKind::kRbrace) {
ShapeIndex out;
if (!ParseShapeIndex(&out)) {
return false;
}
std::string errmsg =
"Expected format: <output_shape_index>: (<operand_index>, "
"<operand_shape_index>)";
if (!ParseToken(TokKind::kColon, errmsg)) {
return false;
}
if (!ParseToken(TokKind::kLparen, errmsg)) {
return false;
}
int64 operand_index;
ParseInt64(&operand_index);
if (!ParseToken(TokKind::kComma, errmsg)) {
return false;
}
ShapeIndex operand_shape_index;
if (!ParseShapeIndex(&operand_shape_index)) {
return false;
}
aliasing_output_operand_pairs->emplace_back(
out, std::pair<int64, ShapeIndex>{operand_index, operand_shape_index});
if (!ParseToken(TokKind::kRparen, errmsg)) {
return false;
}
if (!EatIfPresent(TokKind::kComma)) {
break;
}
}
if (!ParseToken(
TokKind::kRbrace,
"Expects '}' at the end of instruction aliasing description")) {
return false;
}
return true;
}
// ::= 'HloModule' name computations
bool HloParserImpl::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
@ -1836,8 +1777,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
optional<HloComputation*> to_apply;
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
output_to_operand_aliasing;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@ -1853,9 +1792,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
&custom_call_has_side_effect};
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
&to_apply};
attrs["output_to_operand_aliasing"] = {/*required=*/false,
AttrTy::kInstructionAliasing,
&output_to_operand_aliasing};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -1925,10 +1861,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
}
if (output_to_operand_aliasing.has_value()) {
custom_call_instr->set_output_to_operand_aliasing(
std::move(*output_to_operand_aliasing));
}
break;
}
case HloOpcode::kDot: {
@ -3291,19 +3223,6 @@ bool HloParserImpl::ParseAttributeHelper(
->emplace(aliasing_data);
return true;
}
case AttrTy::kInstructionAliasing: {
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
aliasing_output_operand_pairs;
if (!ParseInstructionOutputOperandAliasing(
&aliasing_output_operand_pairs)) {
return false;
}
static_cast<optional<
std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>*>(
attr_out_ptr)
->emplace(std::move(aliasing_output_operand_pairs));
return true;
}
}
}();
if (!success) {

View File

@ -991,19 +991,6 @@ ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4])
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
}
)"
},
// CustomCallWithAliasing
{
"CustomCallWithAliasing",
R"(HloModule CustomCallWithAliasing
ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
}
)"
},
// Parse c64 literal

View File

@ -801,28 +801,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
}
}
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
TF_RET_CHECK(pair.second.first < custom_call->operand_count())
<< "Invalid aliasing operand index.";
TF_RET_CHECK(ShapeUtil::IndexIsValid(
custom_call->operand(pair.second.first)->shape(), pair.second.second))
<< "Invalid aliasing operand shape index.";
TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
<< "Invalid aliasing output shape index.";
const Shape& output_subshape =
ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
const Shape& operand_subshape = ShapeUtil::GetSubshape(
custom_call->operand(pair.second.first)->shape(), pair.second.second);
if (layout_sensitive_) {
TF_RET_CHECK(operand_subshape == output_subshape)
<< "Different aliasing shapes: " << operand_subshape.ToString()
<< " vs " << output_subshape.ToString();
} else {
TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
<< "Different aliasing shapes: " << operand_subshape.ToString()
<< " vs " << output_subshape.ToString();
}
}
return Status::OK();
}

View File

@ -15,9 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
namespace xla {
bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
@ -90,17 +87,6 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
return false;
}
}
if (auto* custom_call =
DynCast<HloCustomCallInstruction>(position.instruction)) {
for (const auto& pair : custom_call->output_to_operand_aliasing()) {
if (position.index == pair.first) {
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a custom-call output that "
"aliases an operand buffer.";
return false;
}
}
}
}
return true;

View File

@ -691,11 +691,3 @@ message WhileLoopBackendConfig {
// unknown-trip-count.
KnownTripCount known_trip_count = 1;
}
// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
repeated int64 output_shape_index = 1;
int64 operand_index = 2;
repeated int64 operand_shape_index = 3;
}