From cf315d3e75a45efc984a2636d222c75185b85c70 Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Tue, 10 Mar 2020 19:11:02 -0700
Subject: [PATCH] Add AddOpWithShape virtual method that only has shape, op
 name and operands fields

This method can be used by ops that don't have any other fields in the
instruction proto, and it allows derived classes to easily override op
creation for such ops. For now, use it for UnaryOp, TernaryOp and
ConvertElementType ops.

PiperOrigin-RevId: 300229493
Change-Id: Ic8790523ce78a995716a34610b59f061e95fcf2a
---
 tensorflow/compiler/xla/client/xla_builder.cc | 20 ++++++++++---------
 tensorflow/compiler/xla/client/xla_builder.h  |  4 ++++
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 59c78366da2..3939f912fa2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -499,12 +499,10 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
 
 XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(
         Shape shape, ShapeInference::InferUnaryOpShape(unop, *operand_shape));
-    *instr.mutable_shape() = shape.ToProto();
-    return AddInstruction(std::move(instr), unop, {operand});
+    return AddOpWithShape(unop, shape, {operand});
   });
 }
 
@@ -592,7 +590,6 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(
 
 XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     XlaOp updated_lhs = lhs;
     XlaOp updated_rhs = rhs;
     XlaOp updated_ehs = ehs;
@@ -645,8 +642,8 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
           "%s Input scalar shapes may have been changed to non-scalar shapes.",
           status_or_shape.status().error_message());
     }
-    *instr.mutable_shape() = status_or_shape.ConsumeValueOrDie().ToProto();
-    return AddInstruction(std::move(instr), triop,
+
+    return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
                           {updated_lhs, updated_rhs, updated_ehs});
   });
 }
@@ -1626,12 +1623,10 @@ XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
 
 XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                      PrimitiveType new_element_type) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
                                          *operand_shape, new_element_type));
-    *instr.mutable_shape() = shape.ToProto();
-    return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
+    return AddOpWithShape(HloOpcode::kConvert, shape, {operand});
   });
 }
@@ -2815,6 +2810,13 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
   return op;
 }
 
+StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
+                                           absl::Span<const XlaOp> operands) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  return AddInstruction(std::move(instr), opcode, operands);
+}
+
 void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                       HloInstructionProto* instr) {
   absl::flat_hash_map<int64, int64> remapped_ids;
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index f4c61462928..9d03141715f 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -1075,6 +1075,10 @@ class XlaBuilder {
       absl::Span<const XlaComputation* const> branch_computations,
       absl::Span<const XlaOp> branch_operands);
 
+  // Creates an op with the given opcode and the output shape.
+  virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
+                                         absl::Span<const XlaOp> operands);
+
   // Here, InstructionType is either const HloInstructionProto* or non-const
   // HloInstructionProto*.
   template <typename InstructionType>
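
Note (not part of the patch): below is a minimal, hypothetical sketch of how a
subclass could use the new hook. The class name InterceptingBuilder, its
error-returning body, and the ExampleUsage function are illustrative
assumptions; only the AddOpWithShape signature and the fact that UnaryOp,
TernaryOp and ConvertElementType now route through it come from the change
above. A real derived builder (for example, one that emits a different IR
instead of HloInstructionProtos) would construct its own op representation
inside the override.

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Hypothetical subclass: intercepts every op that XlaBuilder now creates via
// AddOpWithShape and reports what was requested instead of building the
// HloInstructionProto. C++ lets a derived class override a virtual function
// regardless of the access level of the base declaration, so the override
// works even though the hook sits among XlaBuilder's internal op-construction
// declarations in the hunk above.
class InterceptingBuilder : public XlaBuilder {
 public:
  using XlaBuilder::XlaBuilder;

  StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                 absl::Span<const XlaOp> operands) override {
    // A real backend would create its own op here; this sketch only reports
    // which requests reach the hook.
    return Unimplemented("op %s with result shape %s and %d operand(s)",
                         HloOpcodeString(opcode),
                         ShapeUtil::HumanString(shape), operands.size());
  }
};

// Usage sketch: Neg() goes through XlaBuilder::UnaryOp, which after this
// patch calls AddOpWithShape, so it hits the override; the resulting status
// surfaces from Build().
void ExampleUsage() {
  InterceptingBuilder b("example");
  XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "p");
  Neg(p);
  StatusOr<XlaComputation> computation = b.Build();
  (void)computation;
}

}  // namespace xla

Because the hook takes only the opcode, the inferred result shape, and the
operands, any op whose instruction proto needs no other fields can be routed
through this single virtual method, which is why the patch switches UnaryOp,
TernaryOp and ConvertElementType over to it first.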