Allow CustomCall HLO to take a to_apply.

PiperOrigin-RevId: 293884304
Change-Id: I5349c3c4cbf956741c0ad052a564e753e78d367d
This commit is contained in:
A. Unique TensorFlower 2020-02-07 13:37:44 -08:00 committed by TensorFlower Gardener
parent aed2647941
commit 5fc54e2153
7 changed files with 64 additions and 6 deletions

View File

@ -66,6 +66,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSort:
case HloOpcode::kFusion:
case HloOpcode::kCustomCall:
return CallContext::kParallel;
default:
return CallContext::kNone;

View File

@ -497,9 +497,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
operand_shapes, proto.backend_config());
} else {
instruction =
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
proto.backend_config());
if (proto.called_computation_ids_size() == 1) {
instruction = CreateCustomCall(shape, all_operands(), computations(0),
proto.custom_call_target(),
proto.backend_config());
} else {
instruction = CreateCustomCall(shape, all_operands(),
proto.custom_call_target(),
proto.backend_config());
}
}
auto custom_call_instr =
Cast<HloCustomCallInstruction>(instruction.get());
@ -1408,6 +1414,14 @@ bool HloInstruction::HasSideEffect() const {
shape, operands, custom_call_target, std::move(opaque));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* to_apply, absl::string_view custom_call_target,
string opaque) {
return absl::make_unique<HloCustomCallInstruction>(
shape, operands, to_apply, custom_call_target, std::move(opaque));
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target,
@ -2150,6 +2164,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kAllReduce:
case HloOpcode::kScatter:
case HloOpcode::kSort:
case HloOpcode::kCustomCall:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@ -2170,6 +2185,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kAllReduce:
case HloOpcode::kScatter:
case HloOpcode::kSort:
case HloOpcode::kCustomCall:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;

View File

@ -937,6 +937,12 @@ class HloInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, string opaque = "");
// Overload with a to_apply computation
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* to_apply, absl::string_view custom_call_target,
string opaque = "");
// Overload which constrains the layouts of the operand and result. 'shape'
// and 'operand_shapes_with_layout' must have layouts.
// 'operand_shapes_with_layout' must have a compatible element for each

View File

@ -2180,6 +2180,23 @@ HloCustomCallInstruction::HloCustomCallInstruction(
}
}
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
HloComputation* to_apply, absl::string_view custom_call_target,
string opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
feature_group_count_(1),
batch_group_count_(1),
layout_constrained_(false),
custom_call_has_side_effect_(false) {
set_raw_backend_config_string(std::move(opaque));
for (auto operand : operands) {
AppendOperand(operand);
}
AppendComputation(to_apply);
}
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, string opaque,

View File

@ -1259,6 +1259,12 @@ class HloCustomCallInstruction : public HloInstruction {
absl::string_view custom_call_target, string opaque,
absl::Span<const Shape> operand_shapes_with_layout);
// Constructor for a custom call with a to_apply computation.
HloCustomCallInstruction(const Shape& shape,
absl::Span<HloInstruction* const> operands,
HloComputation* to_apply,
absl::string_view custom_call_target, string opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;

View File

@ -1583,6 +1583,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
optional<int64> batch_group_count;
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
optional<HloComputation*> to_apply;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@ -1596,6 +1597,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
/*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
&custom_call_has_side_effect};
attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@ -1636,9 +1639,17 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
shape, operands, *custom_call_target, *operand_layout_constraints,
backend_config ? *backend_config : ""));
} else {
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target,
backend_config ? *backend_config : ""));
if (to_apply.has_value()) {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *to_apply, *custom_call_target,
backend_config ? *backend_config : ""));
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target,
backend_config ? *backend_config : ""));
}
}
auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
if (window.has_value()) {

View File

@ -46,6 +46,7 @@ bool IsCallerInstruction(HloInstruction* hlo) {
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSort:
case HloOpcode::kFusion:
case HloOpcode::kCustomCall:
return true;
default:
return false;