Add MakeUnaryHlo() and MakeReverseHlo() to hlo_creation_utils.h/.cc
PiperOrigin-RevId: 296080049 Change-Id: I81d020a76da6820086a1a50379c77efc6c43918c
This commit is contained in:
parent
14ff7577c9
commit
2b95bfb6d8
@ -33,6 +33,15 @@ limitations under the License.
|
||||
namespace xla {
|
||||
using absl::StrCat;
|
||||
|
||||
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
|
||||
HloInstruction* operand) {
|
||||
HloComputation* computation = operand->parent();
|
||||
TF_ASSIGN_OR_RETURN(Shape unary_op_shape,
|
||||
ShapeInference::InferUnaryOpShape(opcode, operand));
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(unary_op_shape, opcode, operand));
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
HloComputation* computation = lhs->parent();
|
||||
@ -344,6 +353,15 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
|
||||
scalar_shape, operand, init_value, all_dims, reduce_computation));
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
|
||||
absl::Span<const int64> dimensions) {
|
||||
HloComputation* computation = operand->parent();
|
||||
TF_ASSIGN_OR_RETURN(Shape reverse_shape, ShapeInference::InferReverseShape(
|
||||
operand->shape(), dimensions));
|
||||
return computation->AddInstruction(
|
||||
HloInstruction::CreateReverse(reverse_shape, operand, dimensions));
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
|
||||
HloInstruction* on_true,
|
||||
HloInstruction* on_false,
|
||||
|
@ -27,6 +27,11 @@ namespace xla {
|
||||
// ergonomic. We don't have a complete set of helpers yet -- I expect we'll
|
||||
// expand this interface as needed on an ad-hoc basis.
|
||||
|
||||
// Creates a unary HLO instruction and adds it to the computation containing
|
||||
// `operand`.
|
||||
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
|
||||
HloInstruction* operand);
|
||||
|
||||
// Creates a binary HLO instruction and adds it to the computation containing
|
||||
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
|
||||
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
|
||||
@ -145,6 +150,11 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
|
||||
HloOpcode binary_opcode,
|
||||
HloModule* module);
|
||||
|
||||
// Creates a Reverse HLO instruction and adds it to the computation containing
|
||||
// `operand`.
|
||||
StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
|
||||
absl::Span<const int64> dimensions);
|
||||
|
||||
// Creates a Select HLO instruction and adds it to the computation containing
|
||||
// the predicate. The on_true and on_false instructions must also be contained
|
||||
// in the same computation. If on_true and on_false are tuples, create a tuple
|
||||
|
Loading…
x
Reference in New Issue
Block a user