Allow CustomCall HLO to take a to_apply.
PiperOrigin-RevId: 293884304 Change-Id: I5349c3c4cbf956741c0ad052a564e753e78d367d
This commit is contained in:
parent
aed2647941
commit
5fc54e2153
@ -66,6 +66,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kCustomCall:
|
||||
return CallContext::kParallel;
|
||||
default:
|
||||
return CallContext::kNone;
|
||||
|
||||
@ -497,9 +497,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
|
||||
operand_shapes, proto.backend_config());
|
||||
} else {
|
||||
instruction =
|
||||
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
|
||||
proto.backend_config());
|
||||
if (proto.called_computation_ids_size() == 1) {
|
||||
instruction = CreateCustomCall(shape, all_operands(), computations(0),
|
||||
proto.custom_call_target(),
|
||||
proto.backend_config());
|
||||
} else {
|
||||
instruction = CreateCustomCall(shape, all_operands(),
|
||||
proto.custom_call_target(),
|
||||
proto.backend_config());
|
||||
}
|
||||
}
|
||||
auto custom_call_instr =
|
||||
Cast<HloCustomCallInstruction>(instruction.get());
|
||||
@ -1408,6 +1414,14 @@ bool HloInstruction::HasSideEffect() const {
|
||||
shape, operands, custom_call_target, std::move(opaque));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* to_apply, absl::string_view custom_call_target,
|
||||
string opaque) {
|
||||
return absl::make_unique<HloCustomCallInstruction>(
|
||||
shape, operands, to_apply, custom_call_target, std::move(opaque));
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target,
|
||||
@ -2150,6 +2164,7 @@ HloComputation* HloInstruction::to_apply() const {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kScatter:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kCustomCall:
|
||||
CHECK_EQ(called_computations_.size(), 1);
|
||||
return called_computations_[0];
|
||||
default:
|
||||
@ -2170,6 +2185,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kScatter:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kCustomCall:
|
||||
CHECK_EQ(called_computations_.size(), 1);
|
||||
called_computations_[0] = computation;
|
||||
break;
|
||||
|
||||
@ -937,6 +937,12 @@ class HloInstruction {
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target, string opaque = "");
|
||||
|
||||
// Overload with a to_apply computation
|
||||
static std::unique_ptr<HloInstruction> CreateCustomCall(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* to_apply, absl::string_view custom_call_target,
|
||||
string opaque = "");
|
||||
|
||||
// Overload which constrains the layouts of the operand and result. 'shape'
|
||||
// and 'operand_shapes_with_layout' must have layouts.
|
||||
// 'operand_shapes_with_layout' must have a compatible element for each
|
||||
|
||||
@ -2180,6 +2180,23 @@ HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
}
|
||||
}
|
||||
|
||||
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* to_apply, absl::string_view custom_call_target,
|
||||
string opaque)
|
||||
: HloInstruction(HloOpcode::kCustomCall, shape),
|
||||
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
|
||||
feature_group_count_(1),
|
||||
batch_group_count_(1),
|
||||
layout_constrained_(false),
|
||||
custom_call_has_side_effect_(false) {
|
||||
set_raw_backend_config_string(std::move(opaque));
|
||||
for (auto operand : operands) {
|
||||
AppendOperand(operand);
|
||||
}
|
||||
AppendComputation(to_apply);
|
||||
}
|
||||
|
||||
HloCustomCallInstruction::HloCustomCallInstruction(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> operands,
|
||||
absl::string_view custom_call_target, string opaque,
|
||||
|
||||
@ -1259,6 +1259,12 @@ class HloCustomCallInstruction : public HloInstruction {
|
||||
absl::string_view custom_call_target, string opaque,
|
||||
absl::Span<const Shape> operand_shapes_with_layout);
|
||||
|
||||
// Constructor for a custom call with a to_apply computation.
|
||||
HloCustomCallInstruction(const Shape& shape,
|
||||
absl::Span<HloInstruction* const> operands,
|
||||
HloComputation* to_apply,
|
||||
absl::string_view custom_call_target, string opaque);
|
||||
|
||||
const Window& window() const override {
|
||||
CHECK(window_ != nullptr);
|
||||
return *window_;
|
||||
|
||||
@ -1583,6 +1583,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
optional<int64> batch_group_count;
|
||||
optional<std::vector<Shape>> operand_layout_constraints;
|
||||
optional<bool> custom_call_has_side_effect;
|
||||
optional<HloComputation*> to_apply;
|
||||
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
|
||||
&custom_call_target};
|
||||
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
|
||||
@ -1596,6 +1597,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
/*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
|
||||
attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
|
||||
&custom_call_has_side_effect};
|
||||
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
|
||||
&to_apply};
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
@ -1636,9 +1639,17 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
shape, operands, *custom_call_target, *operand_layout_constraints,
|
||||
backend_config ? *backend_config : ""));
|
||||
} else {
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
shape, operands, *custom_call_target,
|
||||
backend_config ? *backend_config : ""));
|
||||
if (to_apply.has_value()) {
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
shape, operands, *to_apply, *custom_call_target,
|
||||
backend_config ? *backend_config : ""));
|
||||
} else {
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
shape, operands, *custom_call_target,
|
||||
backend_config ? *backend_config : ""));
|
||||
}
|
||||
}
|
||||
auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
|
||||
if (window.has_value()) {
|
||||
|
||||
@ -46,6 +46,7 @@ bool IsCallerInstruction(HloInstruction* hlo) {
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kCustomCall:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user