diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 59c78366da2..3939f912fa2 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -499,12 +499,10 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), unop, {operand}); + return AddOpWithShape(unop, shape, {operand}); }); } @@ -592,7 +590,6 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast( XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; XlaOp updated_lhs = lhs; XlaOp updated_rhs = rhs; XlaOp updated_ehs = ehs; @@ -645,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) { "%s Input scalar shapes may have been changed to non-scalar shapes.", status_or_shape.status().error_message()); } - *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto(); - return AddInstruction(std::move(instr), triop, + + return AddOpWithShape(triop, status_or_shape.ValueOrDie(), {updated_lhs, updated_rhs, updated_ehs}); }); } @@ -1626,12 +1623,10 @@ XlaOp XlaBuilder::Sort(absl::Span operands, XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( *operand_shape, new_element_type)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + return AddOpWithShape(HloOpcode::kConvert, shape, {operand}); }); } @@ -2815,6 +2810,13 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, return op; } +StatusOr XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape, + absl::Span operands) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), opcode, operands); +} + void XlaBuilder::AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr) { absl::flat_hash_map remapped_ids; diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index f4c61462928..9d03141715f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -1075,6 +1075,10 @@ class XlaBuilder { absl::Span branch_computations, absl::Span branch_operands); + // Creates an op with the given opcode and the output shape. + virtual StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, + absl::Span operands); + // Here, InstructionType is either const HloInstructionProto* or non-const // HloInstructionProto*. template