Support multiple computations in custom call.
PiperOrigin-RevId: 340503071 Change-Id: Id9baa9795d2f5a48acd59afefd544f0cf7b7ecdb
This commit is contained in:
parent
5feee72aff
commit
6561045a9b
@ -560,6 +560,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
instruction = CreateCustomCall(shape, all_operands(), computations(0),
|
||||
proto.custom_call_target(),
|
||||
proto.backend_config());
|
||||
} else if (proto.called_computation_ids_size() > 1) {
|
||||
instruction = CreateCustomCall(
|
||||
shape, all_operands(), all_computations(),
|
||||
proto.custom_call_target(), proto.backend_config());
|
||||
|
||||
} else {
|
||||
instruction = CreateCustomCall(shape, all_operands(),
|
||||
proto.custom_call_target(),
|
||||
@ -1564,6 +1569,15 @@ bool HloInstruction::HasSideEffect() const {
|
||||
shape, operands, to_apply, custom_call_target, std::move(opaque));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::Span<HloComputation* const> called_computations,
|
||||
absl::string_view custom_call_target, string opaque) {
|
||||
return absl::make_unique<HloCustomCallInstruction>(
|
||||
shape, operands, called_computations, custom_call_target,
|
||||
std::move(opaque));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target,
|
||||
@ -2802,6 +2816,17 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
|
||||
opcode() == HloOpcode::kSort) {
|
||||
extra.push_back(
|
||||
StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
|
||||
} else if (opcode() == HloOpcode::kCustomCall) {
|
||||
if (!called_computations().empty()) {
|
||||
extra.push_back(StrCat(
|
||||
"called_computations={",
|
||||
StrJoin(called_computations(), ", ",
|
||||
[&](string* out, const HloComputation* computation) {
|
||||
StrAppend(
|
||||
out, PrintNameInternal(computation->name(), options));
|
||||
}),
|
||||
"}"));
|
||||
}
|
||||
} else if (!called_computations().empty()) {
|
||||
extra.push_back(StrCat(
|
||||
"calls=",
|
||||
|
@ -983,12 +983,19 @@ class HloInstruction {
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target, string opaque = "");
|
||||
|
||||
// Overload with a to_apply computation
|
||||
// Overload with a to_apply computation.
|
||||
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* to_apply, absl::string_view custom_call_target,
|
||||
string opaque = "");
|
||||
|
||||
// Overload with multiple computations. The called computations can have
|
||||
// different function signatures.
|
||||
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::Span<HloComputation* const> called_computations,
|
||||
absl::string_view custom_call_target, string opaque = "");
|
||||
|
||||
// Overload which constrains the layouts of the operand and result. 'shape'
|
||||
// and 'operand_shapes_with_layout' must have layouts.
|
||||
// 'operand_shapes_with_layout' must have a compatible element for each
|
||||
|
@ -2371,6 +2371,26 @@ HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
AppendComputation(to_apply);
|
||||
}
|
||||
|
||||
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::Span<HloComputation* const> called_computations,
|
||||
absl::string_view custom_call_target, string opaque)
|
||||
: HloInstruction(HloOpcode::kCustomCall, shape),
|
||||
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
|
||||
feature_group_count_(1),
|
||||
batch_group_count_(1),
|
||||
layout_constrained_(false),
|
||||
padding_type_(PaddingType::PADDING_INVALID),
|
||||
custom_call_has_side_effect_(false) {
|
||||
set_raw_backend_config_string(std::move(opaque));
|
||||
for (auto operand : operands) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
for (auto comp : called_computations) {
|
||||
AppendComputation(comp);
|
||||
}
|
||||
}
|
||||
|
||||
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target, string opaque,
|
||||
@ -2531,6 +2551,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
||||
casted_other.precision_config())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (called_computations().size() != other.called_computations().size()) {
|
||||
return false;
|
||||
}
|
||||
for (int64 i = 0; i < called_computations().size(); ++i) {
|
||||
if (!eq_computations(called_computations()[i],
|
||||
other.called_computations()[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Note: backend_config comparison is done in Identical, which is the
|
||||
// intended/exposed way to compare computations, and so not repeated here.
|
||||
return custom_call_target_ == casted_other.custom_call_target_;
|
||||
|
@ -1415,6 +1415,12 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
HloComputation* to_apply,
|
||||
absl::string_view custom_call_target, string opaque);
|
||||
|
||||
// Constructor for a custom call with multiple computations.
|
||||
HloCustomCallInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::Span<HloComputation* const> called_computations,
|
||||
absl::string_view custom_call_target, string opaque);
|
||||
|
||||
const Window& window() const override {
|
||||
CHECK(window_ != nullptr);
|
||||
return *window_;
|
||||
|
@ -1841,6 +1841,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
|
||||
output_to_operand_aliasing;
|
||||
optional<PaddingType> padding_type;
|
||||
optional<std::vector<HloComputation*>> called_computations;
|
||||
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
||||
&custom_call_target};
|
||||
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
||||
@ -1856,6 +1857,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
&custom_call_has_side_effect};
|
||||
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
||||
&to_apply};
|
||||
attrs["called_computations"] = {/*required=*/false,
|
||||
AttrTy::kBracedHloComputationList,
|
||||
&called_computations};
|
||||
attrs["output_to_operand_aliasing"] = {/*required=*/false,
|
||||
AttrTy::kInstructionAliasing,
|
||||
&output_to_operand_aliasing};
|
||||
@ -1865,6 +1869,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
||||
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
||||
&operand_precision};
|
||||
if (called_computations.has_value() && to_apply.has_value()) {
|
||||
return Error(lexer_.GetLoc(),
|
||||
"A single instruction can't have both to_apply and "
|
||||
"calls field");
|
||||
}
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
@ -1910,6 +1919,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
shape, operands, *to_apply, *custom_call_target,
|
||||
backend_config ? *backend_config : ""));
|
||||
} else if (called_computations.has_value()) {
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
shape, operands, *called_computations, *custom_call_target,
|
||||
backend_config ? *backend_config : ""));
|
||||
} else {
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
|
@ -1395,6 +1395,42 @@ ENTRY CustomCall {
|
||||
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// CustomCall with single computation.
|
||||
{
|
||||
"CustumCallSingleComp",
|
||||
R"(HloModule custom_call_with_comp
|
||||
|
||||
max_F32 {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT maximum = f32[] maximum(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY CustomCall {
|
||||
constant = f32[1]{0} constant({12345})
|
||||
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// CustomCall with multiple computations.
|
||||
{
|
||||
"CustumCallMultipleComps",
|
||||
R"(HloModule custom_call_with_comps
|
||||
|
||||
max_F32 {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT maximum = f32[] maximum(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY CustomCall {
|
||||
constant = f32[1]{0} constant({12345})
|
||||
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// Variables with non-default names
|
||||
|
Loading…
x
Reference in New Issue
Block a user