Support multiple computations in custom call.

PiperOrigin-RevId: 340503071
Change-Id: Id9baa9795d2f5a48acd59afefd544f0cf7b7ecdb
This commit is contained in:
Yunxing Dai 2020-11-03 12:39:25 -08:00 committed by TensorFlower Gardener
parent 5feee72aff
commit 6561045a9b
6 changed files with 120 additions and 1 deletions

View File

@ -560,6 +560,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction = CreateCustomCall(shape, all_operands(), computations(0),
proto.custom_call_target(),
proto.backend_config());
} else if (proto.called_computation_ids_size() > 1) {
instruction = CreateCustomCall(
shape, all_operands(), all_computations(),
proto.custom_call_target(), proto.backend_config());
} else {
instruction = CreateCustomCall(shape, all_operands(),
proto.custom_call_target(),
@ -1564,6 +1569,15 @@ bool HloInstruction::HasSideEffect() const {
shape, operands, to_apply, custom_call_target, std::move(opaque));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloComputation* const> called_computations,
absl::string_view custom_call_target, string opaque) {
return absl::make_unique<HloCustomCallInstruction>(
shape, operands, called_computations, custom_call_target,
std::move(opaque));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target,
@ -2802,6 +2816,17 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
opcode() == HloOpcode::kSort) {
extra.push_back(
StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
} else if (opcode() == HloOpcode::kCustomCall) {
if (!called_computations().empty()) {
extra.push_back(StrCat(
"called_computations={",
StrJoin(called_computations(), ", ",
[&](string* out, const HloComputation* computation) {
StrAppend(
out, PrintNameInternal(computation->name(), options));
}),
"}"));
}
} else if (!called_computations().empty()) {
extra.push_back(StrCat(
"calls=",

View File

@ -983,12 +983,19 @@ class HloInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, string opaque = "");
// Overload with a to_apply computation
// Overload with a to_apply computation.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* to_apply, absl::string_view custom_call_target,
string opaque = "");
// Overload with multiple computations. The called computations can have
// different function signatures.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloComputation* const> called_computations,
absl::string_view custom_call_target, string opaque = "");
// Overload which constrains the layouts of the operand and result. 'shape'
// and 'operand_shapes_with_layout' must have layouts.
// 'operand_shapes_with_layout' must have a compatible element for each

View File

@ -2371,6 +2371,26 @@ HloCustomCallInstruction::HloCustomCallInstruction(
AppendComputation(to_apply);
}
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloComputation* const> called_computations,
absl::string_view custom_call_target, string opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
feature_group_count_(1),
batch_group_count_(1),
layout_constrained_(false),
padding_type_(PaddingType::PADDING_INVALID),
custom_call_has_side_effect_(false) {
set_raw_backend_config_string(std::move(opaque));
for (auto operand : operands) {
AppendOperand(operand);
}
for (auto comp : called_computations) {
AppendComputation(comp);
}
}
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, string opaque,
@ -2531,6 +2551,17 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
casted_other.precision_config())) {
return false;
}
if (called_computations().size() != other.called_computations().size()) {
return false;
}
for (int64 i = 0; i < called_computations().size(); ++i) {
if (!eq_computations(called_computations()[i],
other.called_computations()[i])) {
return false;
}
}
// Note: backend_config comparison is done in Identical, which is the
// intended/exposed way to compare computations, and so not repeated here.
return custom_call_target_ == casted_other.custom_call_target_;

View File

@ -1415,6 +1415,12 @@ class HloCustomCallInstruction : public HloInstruction {
HloComputation* to_apply,
absl::string_view custom_call_target, string opaque);
// Constructor for a custom call with multiple computations.
HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloComputation* const> called_computations,
absl::string_view custom_call_target, string opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;

View File

@ -1841,6 +1841,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>>
output_to_operand_aliasing;
optional<PaddingType> padding_type;
optional<std::vector<HloComputation*>> called_computations;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@ -1856,6 +1857,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
&custom_call_has_side_effect};
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
&to_apply};
attrs["called_computations"] = {/*required=*/false,
AttrTy::kBracedHloComputationList,
&called_computations};
attrs["output_to_operand_aliasing"] = {/*required=*/false,
AttrTy::kInstructionAliasing,
&output_to_operand_aliasing};
@ -1865,6 +1869,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (called_computations.has_value() && to_apply.has_value()) {
return Error(lexer_.GetLoc(),
"A single instruction can't have both to_apply and "
"calls field");
}
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -1910,6 +1919,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *to_apply, *custom_call_target,
backend_config ? *backend_config : ""));
} else if (called_computations.has_value()) {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *called_computations, *custom_call_target,
backend_config ? *backend_config : ""));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(

View File

@ -1395,6 +1395,42 @@ ENTRY CustomCall {
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
}
)"
},
// CustomCall with single computation.
{
"CustumCallSingleComp",
R"(HloModule custom_call_with_comp
max_F32 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT maximum = f32[] maximum(lhs, rhs)
}
ENTRY CustomCall {
constant = f32[1]{0} constant({12345})
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
}
)"
},
// CustomCall with multiple computations.
{
"CustumCallMultipleComps",
R"(HloModule custom_call_with_comps
max_F32 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT maximum = f32[] maximum(lhs, rhs)
}
ENTRY CustomCall {
constant = f32[1]{0} constant({12345})
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
}
)"
},
// Variables with non-default names