Patch summary: let kCustomCall carry one called computation ("to_apply").
Adds a CreateCustomCall overload and HloCustomCallInstruction constructor
taking the computation, restores it from proto when exactly one called
computation id is present, and makes to_apply()/set_to_apply() and the call
graph accept kCustomCall. Blob ids on the index lines are carried over from
the original patch.

diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 968301fd5df..949bc13fe9a 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -66,6 +66,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
     case HloOpcode::kSelectAndScatter:
     case HloOpcode::kSort:
     case HloOpcode::kFusion:
+    case HloOpcode::kCustomCall:
       return CallContext::kParallel;
     default:
       return CallContext::kNone;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 06e25d3d0be..9f45cac028c 100755
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -497,5 +497,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                              operand_shapes, proto.backend_config());
+      } else if (proto.called_computation_ids_size() == 1) {
+        // Exactly one called computation id: attach it as to_apply.
+        // NOTE(review): the layout-constrained branch above does not.
+        instruction = CreateCustomCall(shape, all_operands(), computations(0),
+                                       proto.custom_call_target(),
+                                       proto.backend_config());
       } else {
         instruction =
             CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
@@ -1408,6 +1414,14 @@ bool HloInstruction::HasSideEffect() const {
       shape, operands, custom_call_target, std::move(opaque));
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    HloComputation* to_apply, absl::string_view custom_call_target,
+    string opaque) {
+  return absl::make_unique<HloCustomCallInstruction>(
+      shape, operands, to_apply, custom_call_target, std::move(opaque));
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     absl::string_view custom_call_target,
@@ -2150,6 +2164,7 @@ HloComputation* HloInstruction::to_apply() const {
     case HloOpcode::kAllReduce:
     case HloOpcode::kScatter:
     case HloOpcode::kSort:
+    case HloOpcode::kCustomCall:
       CHECK_EQ(called_computations_.size(), 1);
       return called_computations_[0];
     default:
@@ -2170,6 +2185,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
     case HloOpcode::kAllReduce:
     case HloOpcode::kScatter:
     case HloOpcode::kSort:
+    case HloOpcode::kCustomCall:
       CHECK_EQ(called_computations_.size(), 1);
       called_computations_[0] = computation;
       break;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 8382af96a9b..a108a91d5f9 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -937,6 +937,13 @@ class HloInstruction {
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       absl::string_view custom_call_target, string opaque = "");
 
+  // Overload which also attaches a `to_apply` computation to the custom call.
+  // The resulting instruction is not layout-constrained.
+  static std::unique_ptr<HloInstruction> CreateCustomCall(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      HloComputation* to_apply, absl::string_view custom_call_target,
+      string opaque = "");
+
   // Overload which constrains the layouts of the operand and result. 'shape'
   // and 'operand_shapes_with_layout' must have layouts.
   // 'operand_shapes_with_layout' must have a compatible element for each
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index bb5597549ff..1a30062a574 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2180,6 +2180,23 @@ HloCustomCallInstruction::HloCustomCallInstruction(
   }
 }
 
+HloCustomCallInstruction::HloCustomCallInstruction(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    HloComputation* to_apply, absl::string_view custom_call_target,
+    string opaque)
+    : HloInstruction(HloOpcode::kCustomCall, shape),
+      custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+      feature_group_count_(1),
+      batch_group_count_(1),
+      layout_constrained_(false),
+      custom_call_has_side_effect_(false) {
+  set_raw_backend_config_string(std::move(opaque));
+  for (auto operand : operands) {
+    AppendOperand(operand);
+  }
+  AppendComputation(to_apply);
+}
+
 HloCustomCallInstruction::HloCustomCallInstruction(
     const Shape& shape, absl::Span<HloInstruction* const> operands,
     absl::string_view custom_call_target, string opaque,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 36a493efe34..f23453bc0be 100755
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1259,6 +1259,12 @@ class HloCustomCallInstruction : public HloInstruction {
                            absl::string_view custom_call_target, string opaque,
                            absl::Span<const Shape> operand_shapes_with_layout);
 
+  // Constructor for a custom call with a to_apply computation.
+  HloCustomCallInstruction(const Shape& shape,
+                           absl::Span<HloInstruction* const> operands,
+                           HloComputation* to_apply,
+                           absl::string_view custom_call_target, string opaque);
+
   const Window& window() const override {
     CHECK(window_ != nullptr);
     return *window_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11890cee91c..95a18c8daa7 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1583,6 +1583,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
       optional<int64> batch_group_count;
       optional<std::vector<Shape>> operand_layout_constraints;
       optional<bool> custom_call_has_side_effect;
+      optional<HloComputation*> to_apply;
       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                      &custom_call_target};
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
@@ -1596,6 +1597,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
           /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
       attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
                                               &custom_call_has_side_effect};
+      attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
+                           &to_apply};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
@@ -1636,5 +1639,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
             shape, operands, *custom_call_target, *operand_layout_constraints,
             backend_config ? *backend_config : ""));
+      } else if (to_apply.has_value()) {
+        // A to_apply attribute was given: build the overload that takes it.
+        // NOTE(review): the layout-constrained branch above ignores to_apply.
+        instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+            shape, operands, *to_apply, *custom_call_target,
+            backend_config ? *backend_config : ""));
       } else {
         instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
             shape, operands, *custom_call_target,
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index e0dacfb8e00..6e87a95a14e 100755
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -46,6 +46,7 @@ bool IsCallerInstruction(HloInstruction* hlo) {
     case HloOpcode::kSelectAndScatter:
     case HloOpcode::kSort:
     case HloOpcode::kFusion:
+    case HloOpcode::kCustomCall:
       return true;
     default:
       return false;