From 6561045a9b76e1549aa618dacf58e152dfdd0f3b Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Tue, 3 Nov 2020 12:39:25 -0800 Subject: [PATCH] Support multiple computations in custom call. PiperOrigin-RevId: 340503071 Change-Id: Id9baa9795d2f5a48acd59afefd544f0cf7b7ecdb --- .../compiler/xla/service/hlo_instruction.cc | 25 +++++++++++++ .../compiler/xla/service/hlo_instruction.h | 9 ++++- .../compiler/xla/service/hlo_instructions.cc | 31 ++++++++++++++++ .../compiler/xla/service/hlo_instructions.h | 6 ++++ tensorflow/compiler/xla/service/hlo_parser.cc | 14 ++++++++ .../compiler/xla/service/hlo_parser_test.cc | 36 +++++++++++++++++++ 6 files changed, 120 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index bc6dc23eea7..f21c8201dfe 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -560,6 +560,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateCustomCall(shape, all_operands(), computations(0), proto.custom_call_target(), proto.backend_config()); + } else if (proto.called_computation_ids_size() > 1) { + instruction = CreateCustomCall( + shape, all_operands(), all_computations(), + proto.custom_call_target(), proto.backend_config()); + } else { instruction = CreateCustomCall(shape, all_operands(), proto.custom_call_target(), @@ -1564,6 +1569,15 @@ bool HloInstruction::HasSideEffect() const { shape, operands, to_apply, custom_call_target, std::move(opaque)); } +/* static */ std::unique_ptr HloInstruction::CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::Span called_computations, + absl::string_view custom_call_target, string opaque) { + return absl::make_unique( + shape, operands, called_computations, custom_call_target, + std::move(opaque)); +} + /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, absl::Span operands, absl::string_view custom_call_target, @@ -2802,6 +2816,17 @@ std::vector HloInstruction::ExtraAttributesToString( opcode() == HloOpcode::kSort) { extra.push_back( StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options))); + } else if (opcode() == HloOpcode::kCustomCall) { + if (!called_computations().empty()) { + extra.push_back(StrCat( + "called_computations={", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend( + out, PrintNameInternal(computation->name(), options)); + }), + "}")); + } } else if (!called_computations().empty()) { extra.push_back(StrCat( "calls=", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8f032c0b184..1dcaeb4e114 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -983,12 +983,19 @@ class HloInstruction { const Shape& shape, absl::Span operands, absl::string_view custom_call_target, string opaque = ""); - // Overload with a to_apply computation + // Overload with a to_apply computation. static std::unique_ptr CreateCustomCall( const Shape& shape, absl::Span operands, HloComputation* to_apply, absl::string_view custom_call_target, string opaque = ""); + // Overload with multiple computations. The called computations can have + // different function signatures. + static std::unique_ptr CreateCustomCall( + const Shape& shape, absl::Span operands, + absl::Span called_computations, + absl::string_view custom_call_target, string opaque = ""); + // Overload which constrains the layouts of the operand and result. 'shape' // and 'operand_shapes_with_layout' must have layouts. // 'operand_shapes_with_layout' must have a compatible element for each diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 5d5b62359e0..7b84e6e0700 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2371,6 +2371,26 @@ HloCustomCallInstruction::HloCustomCallInstruction( AppendComputation(to_apply); } +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::Span called_computations, + absl::string_view custom_call_target, string opaque) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + feature_group_count_(1), + batch_group_count_(1), + layout_constrained_(false), + padding_type_(PaddingType::PADDING_INVALID), + custom_call_has_side_effect_(false) { + set_raw_backend_config_string(std::move(opaque)); + for (auto operand : operands) { + AppendOperand(operand); + } + for (auto comp : called_computations) { + AppendComputation(comp); + } +} + HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span operands, absl::string_view custom_call_target, string opaque, @@ -2531,6 +2551,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.precision_config())) { return false; } + + if (called_computations().size() != other.called_computations().size()) { + return false; + } + for (int64 i = 0; i < called_computations().size(); ++i) { + if (!eq_computations(called_computations()[i], + other.called_computations()[i])) { + return false; + } + } + // Note: backend_config comparison is done in Identical, which is the // intended/exposed way to compare computations, and so not repeated here. return custom_call_target_ == casted_other.custom_call_target_; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 15c5fbd276d..bacbce15206 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1415,6 +1415,12 @@ class HloCustomCallInstruction : public HloInstruction { HloComputation* to_apply, absl::string_view custom_call_target, string opaque); + // Constructor for a custom call with multiple computations. + HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::Span called_computations, + absl::string_view custom_call_target, string opaque); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index e0072d91738..1dc505e2e60 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1841,6 +1841,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional>>> output_to_operand_aliasing; optional padding_type; + optional> called_computations; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; @@ -1856,6 +1857,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, &custom_call_has_side_effect}; attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation, &to_apply}; + attrs["called_computations"] = {/*required=*/false, + AttrTy::kBracedHloComputationList, + &called_computations}; attrs["output_to_operand_aliasing"] = {/*required=*/false, AttrTy::kInstructionAliasing, &output_to_operand_aliasing}; @@ -1865,6 +1869,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; + if (called_computations.has_value() && to_apply.has_value()) { + return Error(lexer_.GetLoc(), + "A single instruction can't have both to_apply and " + "calls field"); + } if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1910,6 +1919,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, builder->AddInstruction(HloInstruction::CreateCustomCall( shape, operands, *to_apply, *custom_call_target, backend_config ? *backend_config : "")); + } else if (called_computations.has_value()) { + instruction = + builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *called_computations, *custom_call_target, + backend_config ? *backend_config : "")); } else { instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index fd7ce24395e..7e2d009c7f9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1395,6 +1395,42 @@ ENTRY CustomCall { ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" } +)" +}, +// CustomCall with single computation. +{ +"CustumCallSingleComp", +R"(HloModule custom_call_with_comp + +max_F32 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT maximum = f32[] maximum(lhs, rhs) +} + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32} +} + +)" +}, +// CustomCall with multiple computations. +{ +"CustumCallMultipleComps", +R"(HloModule custom_call_with_comps + +max_F32 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT maximum = f32[] maximum(lhs, rhs) +} + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32} +} + )" }, // Variables with non-default names