Support multiple computations in custom call.
PiperOrigin-RevId: 340503071 Change-Id: Id9baa9795d2f5a48acd59afefd544f0cf7b7ecdb
This commit is contained in:
parent
5feee72aff
commit
6561045a9b
@ -560,6 +560,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
instruction = CreateCustomCall(shape, all_operands(), computations(0),
|
instruction = CreateCustomCall(shape, all_operands(), computations(0),
|
||||||
proto.custom_call_target(),
|
proto.custom_call_target(),
|
||||||
proto.backend_config());
|
proto.backend_config());
|
||||||
|
} else if (proto.called_computation_ids_size() > 1) {
|
||||||
|
instruction = CreateCustomCall(
|
||||||
|
shape, all_operands(), all_computations(),
|
||||||
|
proto.custom_call_target(), proto.backend_config());
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
instruction = CreateCustomCall(shape, all_operands(),
|
instruction = CreateCustomCall(shape, all_operands(),
|
||||||
proto.custom_call_target(),
|
proto.custom_call_target(),
|
||||||
@ -1564,6 +1569,15 @@ bool HloInstruction::HasSideEffect() const {
|
|||||||
shape, operands, to_apply, custom_call_target, std::move(opaque));
|
shape, operands, to_apply, custom_call_target, std::move(opaque));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
|
absl::Span<HloComputation* const> called_computations,
|
||||||
|
absl::string_view custom_call_target, string opaque) {
|
||||||
|
return absl::make_unique<HloCustomCallInstruction>(
|
||||||
|
shape, operands, called_computations, custom_call_target,
|
||||||
|
std::move(opaque));
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
absl::string_view custom_call_target,
|
absl::string_view custom_call_target,
|
||||||
@ -2802,6 +2816,17 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
|
|||||||
opcode() == HloOpcode::kSort) {
|
opcode() == HloOpcode::kSort) {
|
||||||
extra.push_back(
|
extra.push_back(
|
||||||
StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
|
StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
|
||||||
|
} else if (opcode() == HloOpcode::kCustomCall) {
|
||||||
|
if (!called_computations().empty()) {
|
||||||
|
extra.push_back(StrCat(
|
||||||
|
"called_computations={",
|
||||||
|
StrJoin(called_computations(), ", ",
|
||||||
|
[&](string* out, const HloComputation* computation) {
|
||||||
|
StrAppend(
|
||||||
|
out, PrintNameInternal(computation->name(), options));
|
||||||
|
}),
|
||||||
|
"}"));
|
||||||
|
}
|
||||||
} else if (!called_computations().empty()) {
|
} else if (!called_computations().empty()) {
|
||||||
extra.push_back(StrCat(
|
extra.push_back(StrCat(
|
||||||
"calls=",
|
"calls=",
|
||||||
|
@ -983,12 +983,19 @@ class HloInstruction {
|
|||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
absl::string_view custom_call_target, string opaque = "");
|
absl::string_view custom_call_target, string opaque = "");
|
||||||
|
|
||||||
// Overload with a to_apply computation
|
// Overload with a to_apply computation.
|
||||||
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
HloComputation* to_apply, absl::string_view custom_call_target,
|
HloComputation* to_apply, absl::string_view custom_call_target,
|
||||||
string opaque = "");
|
string opaque = "");
|
||||||
|
|
||||||
|
// Overload with multiple computations. The called computations can have
|
||||||
|
// different function signatures.
|
||||||
|
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
|
absl::Span<HloComputation* const> called_computations,
|
||||||
|
absl::string_view custom_call_target, string opaque = "");
|
||||||
|
|
||||||
// Overload which constrains the layouts of the operand and result. 'shape'
|
// Overload which constrains the layouts of the operand and result. 'shape'
|
||||||
// and 'operand_shapes_with_layout' must have layouts.
|
// and 'operand_shapes_with_layout' must have layouts.
|
||||||
// 'operand_shapes_with_layout' must have a compatible element for each
|
// 'operand_shapes_with_layout' must have a compatible element for each
|
||||||
|
@ -2371,6 +2371,26 @@ HloCustomCallInstruction::HloCustomCallInstruction(
|
|||||||
AppendComputation(to_apply);
|
AppendComputation(to_apply);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
|
absl::Span<HloComputation* const> called_computations,
|
||||||
|
absl::string_view custom_call_target, string opaque)
|
||||||
|
: HloInstruction(HloOpcode::kCustomCall, shape),
|
||||||
|
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
|
||||||
|
feature_group_count_(1),
|
||||||
|
batch_group_count_(1),
|
||||||
|
layout_constrained_(false),
|
||||||
|
padding_type_(PaddingType::PADDING_INVALID),
|
||||||
|
custom_call_has_side_effect_(false) {
|
||||||
|
set_raw_backend_config_string(std::move(opaque));
|
||||||
|
for (auto operand : operands) {
|
||||||
|
AppendOperand(operand);
|
||||||
|
}
|
||||||
|
for (auto comp : called_computations) {
|
||||||
|
AppendComputation(comp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
HloCustomCallInstruction::HloCustomCallInstruction(
|
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
absl::string_view custom_call_target, string opaque,
|
absl::string_view custom_call_target, string opaque,
|
||||||
@ -2531,6 +2551,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
|
|||||||
casted_other.precision_config())) {
|
casted_other.precision_config())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (called_computations().size() != other.called_computations().size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int64 i = 0; i < called_computations().size(); ++i) {
|
||||||
|
if (!eq_computations(called_computations()[i],
|
||||||
|
other.called_computations()[i])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Note: backend_config comparison is done in Identical, which is the
|
// Note: backend_config comparison is done in Identical, which is the
|
||||||
// intended/exposed way to compare computations, and so not repeated here.
|
// intended/exposed way to compare computations, and so not repeated here.
|
||||||
return custom_call_target_ == casted_other.custom_call_target_;
|
return custom_call_target_ == casted_other.custom_call_target_;
|
||||||
|
@ -1415,6 +1415,12 @@ class HloCustomCallInstruction : public HloInstruction {
|
|||||||
HloComputation* to_apply,
|
HloComputation* to_apply,
|
||||||
absl::string_view custom_call_target, string opaque);
|
absl::string_view custom_call_target, string opaque);
|
||||||
|
|
||||||
|
// Constructor for a custom call with multiple computations.
|
||||||
|
HloCustomCallInstruction(
|
||||||
|
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||||
|
absl::Span<HloComputation* const> called_computations,
|
||||||
|
absl::string_view custom_call_target, string opaque);
|
||||||
|
|
||||||
const Window& window() const override {
|
const Window& window() const override {
|
||||||
CHECK(window_ != nullptr);
|
CHECK(window_ != nullptr);
|
||||||
return *window_;
|
return *window_;
|
||||||
|
@ -1841,6 +1841,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
|
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
|
||||||
output_to_operand_aliasing;
|
output_to_operand_aliasing;
|
||||||
optional<PaddingType> padding_type;
|
optional<PaddingType> padding_type;
|
||||||
|
optional<std::vector<HloComputation*>> called_computations;
|
||||||
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
||||||
&custom_call_target};
|
&custom_call_target};
|
||||||
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
||||||
@ -1856,6 +1857,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
&custom_call_has_side_effect};
|
&custom_call_has_side_effect};
|
||||||
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
||||||
&to_apply};
|
&to_apply};
|
||||||
|
attrs["called_computations"] = {/*required=*/false,
|
||||||
|
AttrTy::kBracedHloComputationList,
|
||||||
|
&called_computations};
|
||||||
attrs["output_to_operand_aliasing"] = {/*required=*/false,
|
attrs["output_to_operand_aliasing"] = {/*required=*/false,
|
||||||
AttrTy::kInstructionAliasing,
|
AttrTy::kInstructionAliasing,
|
||||||
&output_to_operand_aliasing};
|
&output_to_operand_aliasing};
|
||||||
@ -1865,6 +1869,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
|
||||||
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
|
||||||
&operand_precision};
|
&operand_precision};
|
||||||
|
if (called_computations.has_value() && to_apply.has_value()) {
|
||||||
|
return Error(lexer_.GetLoc(),
|
||||||
|
"A single instruction can't have both to_apply and "
|
||||||
|
"calls field");
|
||||||
|
}
|
||||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1910,6 +1919,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||||
shape, operands, *to_apply, *custom_call_target,
|
shape, operands, *to_apply, *custom_call_target,
|
||||||
backend_config ? *backend_config : ""));
|
backend_config ? *backend_config : ""));
|
||||||
|
} else if (called_computations.has_value()) {
|
||||||
|
instruction =
|
||||||
|
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||||
|
shape, operands, *called_computations, *custom_call_target,
|
||||||
|
backend_config ? *backend_config : ""));
|
||||||
} else {
|
} else {
|
||||||
instruction =
|
instruction =
|
||||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||||
|
@ -1395,6 +1395,42 @@ ENTRY CustomCall {
|
|||||||
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
|
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
)"
|
||||||
|
},
|
||||||
|
// CustomCall with single computation.
|
||||||
|
{
|
||||||
|
"CustumCallSingleComp",
|
||||||
|
R"(HloModule custom_call_with_comp
|
||||||
|
|
||||||
|
max_F32 {
|
||||||
|
lhs = f32[] parameter(0)
|
||||||
|
rhs = f32[] parameter(1)
|
||||||
|
ROOT maximum = f32[] maximum(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY CustomCall {
|
||||||
|
constant = f32[1]{0} constant({12345})
|
||||||
|
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
|
||||||
|
}
|
||||||
|
|
||||||
|
)"
|
||||||
|
},
|
||||||
|
// CustomCall with multiple computations.
|
||||||
|
{
|
||||||
|
"CustumCallMultipleComps",
|
||||||
|
R"(HloModule custom_call_with_comps
|
||||||
|
|
||||||
|
max_F32 {
|
||||||
|
lhs = f32[] parameter(0)
|
||||||
|
rhs = f32[] parameter(1)
|
||||||
|
ROOT maximum = f32[] maximum(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY CustomCall {
|
||||||
|
constant = f32[1]{0} constant({12345})
|
||||||
|
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
|
||||||
|
}
|
||||||
|
|
||||||
)"
|
)"
|
||||||
},
|
},
|
||||||
// Variables with non-default names
|
// Variables with non-default names
|
||||||
|
Loading…
x
Reference in New Issue
Block a user