Add AddOpWithShape virtual method that only has shape, op name and operands fields

This method can be used by ops that don't have any other fields in the instruction proto and allows derived class to easily override these. For now, use it for UnaryOp, TernaryOp and ConvertElementType ops.

PiperOrigin-RevId: 300229493
Change-Id: Ic8790523ce78a995716a34610b59f061e95fcf2a
This commit is contained in:
Smit Hinsu 2020-03-10 19:11:02 -07:00 committed by TensorFlower Gardener
parent 8d2178ea81
commit cf315d3e75
2 changed files with 15 additions and 9 deletions

View File

@ -499,12 +499,10 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), unop, {operand});
return AddOpWithShape(unop, shape, {operand});
});
}
@ -592,7 +590,6 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
XlaOp updated_lhs = lhs;
XlaOp updated_rhs = rhs;
XlaOp updated_ehs = ehs;
@ -645,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
"%s Input scalar shapes may have been changed to non-scalar shapes.",
status_or_shape.status().error_message());
}
*instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
return AddInstruction(std::move(instr), triop,
return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
{updated_lhs, updated_rhs, updated_ehs});
});
}
@ -1626,12 +1623,10 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
PrimitiveType new_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
*operand_shape, new_element_type));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
});
}
@ -2815,6 +2810,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
return op;
}
StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), opcode, operands);
}
void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr) {
absl::flat_hash_map<int64, int64> remapped_ids;

View File

@ -1075,6 +1075,10 @@ class XlaBuilder {
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
// Creates an op with the given opcode and the output shape.
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands);
// Here, InstructionType is either const HloInstructionProto* or non-const
// HloInstructionProto*.
template <typename InstructionType>