Add AddOpWithShape virtual method that only has shape, op name and operands fields
This method can be used by ops that don't have any other fields in the instruction proto and allows derived class to easily override these. For now, use it for UnaryOp, TernaryOp and ConvertElementType ops. PiperOrigin-RevId: 300229493 Change-Id: Ic8790523ce78a995716a34610b59f061e95fcf2a
This commit is contained in:
parent
8d2178ea81
commit
cf315d3e75
@ -499,12 +499,10 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
|
||||
|
||||
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), unop, {operand});
|
||||
return AddOpWithShape(unop, shape, {operand});
|
||||
});
|
||||
}
|
||||
|
||||
@ -592,7 +590,6 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(
|
||||
|
||||
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
XlaOp updated_lhs = lhs;
|
||||
XlaOp updated_rhs = rhs;
|
||||
XlaOp updated_ehs = ehs;
|
||||
@ -645,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
|
||||
"%s Input scalar shapes may have been changed to non-scalar shapes.",
|
||||
status_or_shape.status().error_message());
|
||||
}
|
||||
*instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
|
||||
return AddInstruction(std::move(instr), triop,
|
||||
|
||||
return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
|
||||
{updated_lhs, updated_rhs, updated_ehs});
|
||||
});
|
||||
}
|
||||
@ -1626,12 +1623,10 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
|
||||
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
|
||||
PrimitiveType new_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
|
||||
*operand_shape, new_element_type));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
|
||||
return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
|
||||
});
|
||||
}
|
||||
|
||||
@ -2815,6 +2810,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
return op;
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
|
||||
absl::Span<const XlaOp> operands) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), opcode, operands);
|
||||
}
|
||||
|
||||
void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
||||
HloInstructionProto* instr) {
|
||||
absl::flat_hash_map<int64, int64> remapped_ids;
|
||||
|
@ -1075,6 +1075,10 @@ class XlaBuilder {
|
||||
absl::Span<const XlaComputation* const> branch_computations,
|
||||
absl::Span<const XlaOp> branch_operands);
|
||||
|
||||
// Creates an op with the given opcode and the output shape.
|
||||
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
|
||||
absl::Span<const XlaOp> operands);
|
||||
|
||||
// Here, InstructionType is either const HloInstructionProto* or non-const
|
||||
// HloInstructionProto*.
|
||||
template <typename InstructionType>
|
||||
|
Loading…
Reference in New Issue
Block a user