diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 3da893548bb..ac5e01a0abf 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -135,16 +135,12 @@ StatusOr MlirHloBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { + bool has_side_effect) { if (operand_shapes_with_layout.has_value()) return Unimplemented( "CustomCall doesn't support operands shapes with layout"); TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); - TF_RET_CHECK(output_operand_aliasing.empty()) - << "MLIR CustomCallOp does not support output_operand_aliasing yet"; auto op = builder_.create( loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name), /*has_side_effect=*/builder_.getBoolAttr(has_side_effect), diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 59b4bc7b1e0..00b7aa4d0b0 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -135,9 +135,7 @@ class MlirHloBuilder : public XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) override; + bool has_side_effect) override; StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c7bbf9f8486..3e2a4eb53a7 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1707,9 +1707,7 @@ XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { + bool has_side_effect) { return ReportErrorOrReturn([&]() -> StatusOr { if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( @@ -1741,8 +1739,7 @@ XlaOp XlaBuilder::CustomCall( } } return CustomCallInternal(call_target_name, operands, shape, opaque, - operand_shapes_with_layout, has_side_effect, - output_operand_aliasing); + operand_shapes_with_layout, has_side_effect); }); } @@ -1750,9 +1747,7 @@ StatusOr XlaBuilder::CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { + bool has_side_effect) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_custom_call_target(call_target_name); @@ -1764,16 +1759,6 @@ StatusOr XlaBuilder::CustomCallInternal( } } instr.set_custom_call_has_side_effect(has_side_effect); - for (const auto& pair : output_operand_aliasing) { - auto aliasing = instr.add_custom_call_output_operand_aliasing(); - aliasing->set_operand_index(pair.second.first); - for (int64 index : pair.second.second) { - aliasing->add_operand_shape_index(index); - } - for (int64 index : pair.first) { - aliasing->add_output_shape_index(index); - } - } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); } @@ -1781,9 +1766,7 @@ XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { + bool has_side_effect) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1821,16 +1804,6 @@ XlaOp XlaBuilder::CustomCall( } } AddCalledComputation(computation, &instr); - for (const auto& pair : output_operand_aliasing) { - auto aliasing = instr.add_custom_call_output_operand_aliasing(); - aliasing->set_operand_index(pair.second.first); - for (int64 index : pair.second.second) { - aliasing->add_operand_shape_index(index); - } - for (int64 index : pair.first) { - aliasing->add_output_shape_index(index); - } - } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -3888,39 +3861,31 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, return builder->Call(computation, operands); } -XlaOp CustomCall( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, const string& opaque, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + const string& opaque, bool has_side_effect) { return builder->CustomCall(call_target_name, operands, shape, opaque, /*operand_shapes_with_layout=*/absl::nullopt, - has_side_effect, output_operand_aliasing); + has_side_effect); } -XlaOp CustomCallWithComputation( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing) { - return builder->CustomCall(call_target_name, operands, computation, shape, - opaque, - /*operand_shapes_with_layout=*/absl::nullopt, - has_side_effect, output_operand_aliasing); +XlaOp CustomCallWithComputation(XlaBuilder* builder, + const string& call_target_name, + absl::Span operands, + const XlaComputation& computation, + const Shape& shape, const string& opaque, + bool has_side_effect) { + return builder->CustomCall( + call_target_name, operands, computation, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect); } -XlaOp CustomCallWithLayout( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, const string& opaque, - bool has_side_effect, - absl::Span>> - output_operand_aliasing) { +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const string& opaque, bool has_side_effect) { return builder->CustomCall(call_target_name, operands, shape, opaque, - operand_shapes_with_layout, has_side_effect, - output_operand_aliasing); + operand_shapes_with_layout, has_side_effect); } XlaOp Complex(const XlaOp lhs, const XlaOp rhs, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 55bcd86b493..cd9809c2a20 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -593,9 +593,7 @@ class XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing); + bool has_side_effect); // Internal version of CustomCall without computation that doesn't do op // specific error handling and expects arguments to be legal. CustomCall @@ -604,18 +602,14 @@ class XlaBuilder { const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing); + bool has_side_effect); XlaOp CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing); + bool has_side_effect); XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, @@ -1064,25 +1058,18 @@ class XlaBuilder { const string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); - friend XlaOp CustomCall( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing); + friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + const string& opaque, bool has_side_effect); friend XlaOp CustomCallWithComputation( XlaBuilder* builder, const string& call_target_name, absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing); + const Shape& shape, const string& opaque, bool has_side_effect); friend XlaOp CustomCallWithLayout( XlaBuilder* builder, const string& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, const string& opaque, - bool has_side_effect, - absl::Span>> - output_operand_aliasing); + bool has_side_effect); friend XlaOp Complex(XlaOp real, XlaOp imag, absl::Span broadcast_dimensions); friend XlaOp Conj(XlaOp operand); @@ -1818,39 +1805,30 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, // backend, a call instruction is emitted which targets a symbol with the name // |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, // but |call_target_name| should be short as it may be used in labels. |opaque| -// can encode arbitrarily large amounts of information. |has_side_effect| -// specifies whether the instruction can have side effects. -// |output_operand_aliasing| specifies a list of output/operand buffer pairs -// that alias each other, where the output buffer is represented as a -// ShapeIndex, and the operand buffer is represented as the operand index and -// the ShapeIndex. -XlaOp CustomCall( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape, - const string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}); +// can encode arbitrarily large amounts of information. +XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, const Shape& shape, + const string& opaque = "", bool has_side_effect = false); // Overload which constructs a custom call that applies an Xla computation. -XlaOp CustomCallWithComputation( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}); +XlaOp CustomCallWithComputation(XlaBuilder* builder, + const string& call_target_name, + absl::Span operands, + const XlaComputation& computation, + const Shape& shape, const string& opaque = "", + bool has_side_effect = false); // Overload which constructs a custom call with fixed layouts. The operands will // have the layouts specified by |operand_shapes_with_layout| when provided to // external code, and the external code is expected to produce a result with the // layout specified by |shape_with_layout|. All shapes in |shape_with_layout| // and |operand_shapes_with_layout| must have layouts. -XlaOp CustomCallWithLayout( - XlaBuilder* builder, const string& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}); +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span operands, + const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const string& opaque = "", + bool has_side_effect = false); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index b9aff298d35..170f7749f6a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3492,8 +3492,6 @@ cc_library( hdrs = ["memory_space_assignment_utils.h"], deps = [ ":heap_simulator", - ":hlo", - ":hlo_casting_utils", ], ) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index ac94b2e1d24..c3a7b3a5c14 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 75 +// Next ID: 74 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -232,11 +232,6 @@ message HloInstructionProto { // kCustomCall. bool custom_call_has_side_effect = 65; - // A list of CustomCallOutputOperandAliasing pairs that specifies aliasing - // buffers between output and operands for kCustomCall. - repeated xla.CustomCallOutputOperandAliasing - custom_call_output_operand_aliasing = 74; - // The delta value for kRngGetAndUpdateState. int64 delta = 66; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bc1063f9d48..72899ffe163 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -432,23 +432,6 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { return changed; } -bool HloDataflowAnalysis::UpdateCustomCallValueSet( - HloInstruction* custom_call) { - CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall); - bool changed = false; - for (const auto& aliasing : Cast(custom_call) - ->output_to_operand_aliasing()) { - const HloValueSet& operand_value_set = GetValueSet( - custom_call->operand(aliasing.second.first), aliasing.second.second); - HloValueSet& value_set = GetValueSet(custom_call, aliasing.first); - if (value_set != operand_value_set) { - value_set = operand_value_set; - changed = true; - } - } - return changed; -} - bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) { CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart); bool changed = false; @@ -774,8 +757,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateAddDependencyValueSet(instruction); case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); - case HloOpcode::kCustomCall: - return UpdateCustomCallValueSet(instruction); case HloOpcode::kSetDimensionSize: return UpdateSetDimensionSizeValueSet(instruction); case HloOpcode::kDomain: @@ -1037,22 +1018,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_value_at(/*index=*/{1}); define_value_at(/*index=*/{2}); break; - case HloOpcode::kCustomCall: { - absl::flat_hash_set aliasing_indices; - for (const auto& aliasing : - Cast(instruction) - ->output_to_operand_aliasing()) { - aliasing_indices.insert(aliasing.first); - } - ShapeUtil::ForEachSubshape( - instruction->shape(), - [&](const Shape& /*subshape*/, const ShapeIndex& index) { - if (!aliasing_indices.contains(index)) { - define_value_at(index); - } - }); - break; - } default: define_all_values(); break; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index c3aad04023f..ffa307d71dd 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -216,7 +216,6 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); - bool UpdateCustomCallValueSet(HloInstruction* custom_call); bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 41488dcdaaa..251261a677f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -568,19 +568,6 @@ StatusOr> HloInstruction::CreateFromProto( std::max(static_cast(proto.batch_group_count()), int64{1})); custom_call_instr->set_custom_call_has_side_effect( proto.custom_call_has_side_effect()); - std::vector>> - output_to_operand_aliasing; - for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) { - output_to_operand_aliasing.emplace_back( - ShapeIndex(aliasing.output_shape_index().begin(), - aliasing.output_shape_index().end()), - std::pair{ - aliasing.operand_index(), - ShapeIndex(aliasing.operand_shape_index().begin(), - aliasing.operand_shape_index().end())}); - } - custom_call_instr->set_output_to_operand_aliasing( - std::move(output_to_operand_aliasing)); break; } case HloOpcode::kPad: diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 45b2d885d8e..c4c31dba9a4 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2395,16 +2395,6 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { } } proto.set_custom_call_has_side_effect(custom_call_has_side_effect_); - for (const auto& pair : output_to_operand_aliasing_) { - auto aliasing = proto.add_custom_call_output_operand_aliasing(); - aliasing->set_operand_index(pair.second.first); - for (int64 index : pair.first) { - aliasing->add_output_shape_index(index); - } - for (int64 index : pair.second.second) { - aliasing->add_operand_shape_index(index); - } - } return proto; } @@ -2442,16 +2432,6 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (custom_call_has_side_effect_) { extra.push_back("custom_call_has_side_effect=true"); } - if (!output_to_operand_aliasing_.empty()) { - std::vector pair_strings; - for (const auto& pair : output_to_operand_aliasing_) { - pair_strings.push_back(StrCat(pair.first.ToString(), ": (", - pair.second.first, ", ", - pair.second.second.ToString(), ")")); - } - extra.push_back(StrCat("output_to_operand_aliasing={", - StrJoin(pair_strings, ", "), "}")); - } return extra; } @@ -2495,10 +2475,6 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.custom_call_has_side_effect()) { return false; } - if (output_to_operand_aliasing_ != - casted_other.output_to_operand_aliasing()) { - return false; - } // Note: backend_config comparison is done in Identical, which is the // intended/exposed way to compare computations, and so not repeated here. return custom_call_target_ == casted_other.custom_call_target_; @@ -2523,7 +2499,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_feature_group_count(feature_group_count_); cloned->set_batch_group_count(batch_group_count_); cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_); - cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_); return std::move(cloned); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 88e874347bd..821849bb02f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1430,20 +1430,6 @@ class HloCustomCallInstruction : public HloInstruction { CHECK(layout_constrained()); return operand_shapes_with_layout_; } - // Gets a list of output/operand buffer pairs that alias each other, where the - // output buffer is represented as a ShapeIndex, and the operand buffer is - // represented as the operand index and the ShapeIndex. By default this list - // is empty. - const std::vector>>& - output_to_operand_aliasing() const { - return output_to_operand_aliasing_; - } - // Sets the list of output/operand buffer pairs that alias each other. - void set_output_to_operand_aliasing( - std::vector>> - aliasing) { - output_to_operand_aliasing_ = std::move(aliasing); - } private: std::vector ExtraAttributesToStringImpl( @@ -1472,10 +1458,6 @@ class HloCustomCallInstruction : public HloInstruction { std::vector operand_shapes_with_layout_; // Whether this custom call has a side-effect. bool custom_call_has_side_effect_; - // A list of output/operand buffer pairs that alias each other. See comment of - // output_to_operand_aliasing(). - std::vector>> - output_to_operand_aliasing_; }; class HloPadInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 37bdeaa1073..e2bbda3a607 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -212,7 +212,6 @@ class HloParserImpl : public HloParser { kEnum, kRandomAlgorithm, kAliasing, - kInstructionAliasing, }; struct AttrConfig { @@ -347,12 +346,6 @@ class HloParserImpl : public HloParser { // fails. bool ParseAliasing(AliasingData* data); - // Parses the per-instruction aliasing information from string `s`, returns - // `false` if it fails. - bool ParseInstructionOutputOperandAliasing( - std::vector>>* - aliasing_output_operand_pairs); - bool ParseShapeIndex(ShapeIndex* out); // Returns true if the current token is the beginning of a shape. @@ -605,58 +598,6 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) { return true; } -bool HloParserImpl::ParseInstructionOutputOperandAliasing( - std::vector>>* - aliasing_output_operand_pairs) { - if (!ParseToken( - TokKind::kLbrace, - "Expects '{' at the start of instruction aliasing description")) { - return false; - } - - while (lexer_.GetKind() != TokKind::kRbrace) { - ShapeIndex out; - if (!ParseShapeIndex(&out)) { - return false; - } - std::string errmsg = - "Expected format: : (, " - ")"; - if (!ParseToken(TokKind::kColon, errmsg)) { - return false; - } - - if (!ParseToken(TokKind::kLparen, errmsg)) { - return false; - } - int64 operand_index; - ParseInt64(&operand_index); - if (!ParseToken(TokKind::kComma, errmsg)) { - return false; - } - ShapeIndex operand_shape_index; - if (!ParseShapeIndex(&operand_shape_index)) { - return false; - } - - aliasing_output_operand_pairs->emplace_back( - out, std::pair{operand_index, operand_shape_index}); - if (!ParseToken(TokKind::kRparen, errmsg)) { - return false; - } - - if (!EatIfPresent(TokKind::kComma)) { - break; - } - } - if (!ParseToken( - TokKind::kRbrace, - "Expects '}' at the end of instruction aliasing description")) { - return false; - } - return true; -} - // ::= 'HloModule' name computations bool HloParserImpl::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { @@ -1836,8 +1777,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> operand_layout_constraints; optional custom_call_has_side_effect; optional to_apply; - optional>>> - output_to_operand_aliasing; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; @@ -1853,9 +1792,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, &custom_call_has_side_effect}; attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation, &to_apply}; - attrs["output_to_operand_aliasing"] = {/*required=*/false, - AttrTy::kInstructionAliasing, - &output_to_operand_aliasing}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1925,10 +1861,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, custom_call_instr->set_custom_call_has_side_effect( *custom_call_has_side_effect); } - if (output_to_operand_aliasing.has_value()) { - custom_call_instr->set_output_to_operand_aliasing( - std::move(*output_to_operand_aliasing)); - } break; } case HloOpcode::kDot: { @@ -3291,19 +3223,6 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(aliasing_data); return true; } - case AttrTy::kInstructionAliasing: { - std::vector>> - aliasing_output_operand_pairs; - if (!ParseInstructionOutputOperandAliasing( - &aliasing_output_operand_pairs)) { - return false; - } - static_cast>>>*>( - attr_out_ptr) - ->emplace(std::move(aliasing_output_operand_pairs)); - return true; - } } }(); if (!success) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d220d735622..620e67c3a2f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -991,19 +991,6 @@ ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true } -)" -}, -// CustomCallWithAliasing -{ -"CustomCallWithAliasing", -R"(HloModule CustomCallWithAliasing - -ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) { - %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) - %p1 = f32[123,4]{0,1} parameter(1) - ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})} -} - )" }, // Parse c64 literal diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 4be0c5259cc..b3603e4fe5c 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -801,28 +801,6 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); } } - for (const auto& pair : custom_call->output_to_operand_aliasing()) { - TF_RET_CHECK(pair.second.first < custom_call->operand_count()) - << "Invalid aliasing operand index."; - TF_RET_CHECK(ShapeUtil::IndexIsValid( - custom_call->operand(pair.second.first)->shape(), pair.second.second)) - << "Invalid aliasing operand shape index."; - TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first)) - << "Invalid aliasing output shape index."; - const Shape& output_subshape = - ShapeUtil::GetSubshape(custom_call->shape(), pair.first); - const Shape& operand_subshape = ShapeUtil::GetSubshape( - custom_call->operand(pair.second.first)->shape(), pair.second.second); - if (layout_sensitive_) { - TF_RET_CHECK(operand_subshape == output_subshape) - << "Different aliasing shapes: " << operand_subshape.ToString() - << " vs " << output_subshape.ToString(); - } else { - TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape)) - << "Different aliasing shapes: " << operand_subshape.ToString() - << " vs " << output_subshape.ToString(); - } - } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index aad943aaad7..0c44ae0d766 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" - namespace xla { bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( @@ -90,17 +87,6 @@ bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( return false; } } - if (auto* custom_call = - DynCast(position.instruction)) { - for (const auto& pair : custom_call->output_to_operand_aliasing()) { - if (position.index == pair.first) { - VLOG(4) << "Keeping value " << value->ToShortString() - << " in default mem because it is a custom-call output that " - "aliases an operand buffer."; - return false; - } - } - } } return true; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 2d311dd2f70..d334f879c3e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -691,11 +691,3 @@ message WhileLoopBackendConfig { // unknown-trip-count. KnownTripCount known_trip_count = 1; } - -// Specifies a pair of output/operand buffers for kCustomCall that alias each -// other. -message CustomCallOutputOperandAliasing { - repeated int64 output_shape_index = 1; - int64 operand_index = 2; - repeated int64 operand_shape_index = 3; -}