Add AddOpWithShape virtual method that only has shape, op name and operands fields
This method can be used by ops that don't have any other fields in the instruction proto and allows derived class to easily override these. For now, use it for UnaryOp, TernaryOp and ConvertElementType ops. PiperOrigin-RevId: 300229493 Change-Id: Ic8790523ce78a995716a34610b59f061e95fcf2a
This commit is contained in:
parent
8d2178ea81
commit
cf315d3e75
@ -499,12 +499,10 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
|
|||||||
|
|
||||||
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
|
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
return AddOpWithShape(unop, shape, {operand});
|
||||||
return AddInstruction(std::move(instr), unop, {operand});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -592,7 +590,6 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(
|
|||||||
|
|
||||||
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
|
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
XlaOp updated_lhs = lhs;
|
XlaOp updated_lhs = lhs;
|
||||||
XlaOp updated_rhs = rhs;
|
XlaOp updated_rhs = rhs;
|
||||||
XlaOp updated_ehs = ehs;
|
XlaOp updated_ehs = ehs;
|
||||||
@ -645,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
|
|||||||
"%s Input scalar shapes may have been changed to non-scalar shapes.",
|
"%s Input scalar shapes may have been changed to non-scalar shapes.",
|
||||||
status_or_shape.status().error_message());
|
status_or_shape.status().error_message());
|
||||||
}
|
}
|
||||||
*instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
|
|
||||||
return AddInstruction(std::move(instr), triop,
|
return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
|
||||||
{updated_lhs, updated_rhs, updated_ehs});
|
{updated_lhs, updated_rhs, updated_ehs});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -1626,12 +1623,10 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
|
|||||||
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
|
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
|
||||||
PrimitiveType new_element_type) {
|
PrimitiveType new_element_type) {
|
||||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||||
HloInstructionProto instr;
|
|
||||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
|
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
|
||||||
*operand_shape, new_element_type));
|
*operand_shape, new_element_type));
|
||||||
*instr.mutable_shape() = shape.ToProto();
|
return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
|
||||||
return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2815,6 +2810,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
|||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
|
||||||
|
absl::Span<const XlaOp> operands) {
|
||||||
|
HloInstructionProto instr;
|
||||||
|
*instr.mutable_shape() = shape.ToProto();
|
||||||
|
return AddInstruction(std::move(instr), opcode, operands);
|
||||||
|
}
|
||||||
|
|
||||||
void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
||||||
HloInstructionProto* instr) {
|
HloInstructionProto* instr) {
|
||||||
absl::flat_hash_map<int64, int64> remapped_ids;
|
absl::flat_hash_map<int64, int64> remapped_ids;
|
||||||
|
@ -1075,6 +1075,10 @@ class XlaBuilder {
|
|||||||
absl::Span<const XlaComputation* const> branch_computations,
|
absl::Span<const XlaComputation* const> branch_computations,
|
||||||
absl::Span<const XlaOp> branch_operands);
|
absl::Span<const XlaOp> branch_operands);
|
||||||
|
|
||||||
|
// Creates an op with the given opcode and the output shape.
|
||||||
|
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
|
||||||
|
absl::Span<const XlaOp> operands);
|
||||||
|
|
||||||
// Here, InstructionType is either const HloInstructionProto* or non-const
|
// Here, InstructionType is either const HloInstructionProto* or non-const
|
||||||
// HloInstructionProto*.
|
// HloInstructionProto*.
|
||||||
template <typename InstructionType>
|
template <typename InstructionType>
|
||||||
|
Loading…
Reference in New Issue
Block a user