[XLA:GPU][NFC] Reorder arguments of CreateOperands.

- Move all output values at end and remove default values for num_operands and
  result_subset as they are not needed.
- Use pre-increment when possible to conform to Google coding style.

PiperOrigin-RevId: 350179475
Change-Id: I6f7f8a5e6daf3b71b9efb083e9eb5083ad95a51a
This commit is contained in:
Rahul Joshi 2021-01-05 11:13:12 -08:00 committed by TensorFlower Gardener
parent 2c687ec4d9
commit fea4a23847
2 changed files with 21 additions and 20 deletions

View File

@ -202,14 +202,16 @@ class XlaHloToLhloPass
} // namespace
// Creates MLIR operands corresponding to operands and results of the XLA HLO
// instruction. If `num_operands` is not -1, then only the first `num_operands`
// instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered.
Status LhloDialectEmitter::CreateOperands(
HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) {
HloInstruction* instr, absl::optional<xla::int64> num_operands,
llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
size_t& num_results) {
if (num_operands.value_or(0) > instr->operand_count())
return xla::InvalidArgument("num_operands must be <= operand count");
for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
i++) {
++i) {
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
}
num_arguments = operands.size();
@ -232,8 +234,8 @@ StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) {
llvm::SmallVector<Value, 4> operands;
TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
num_operands));
TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
num_arguments, num_results));
return CreateOpWithoutAttrs<OpType>(instr, operands);
}
@ -396,7 +398,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
::xla::ShapeIndex* shape_index, OpBuilder* b, Location loc) {
if (shape.IsTuple()) {
llvm::SmallVector<Value, 4> values;
for (int i = 0; i < shape.tuple_shapes_size(); i++) {
for (int i = 0; i < shape.tuple_shapes_size(); ++i) {
shape_index->push_back(i);
TF_ASSIGN_OR_RETURN(
auto v, RewriteFusionOperand(root, shape.tuple_shapes(i), shape_index,
@ -432,7 +434,7 @@ StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
auto region_builder = OpBuilder::atBlockBegin(&fusion.region().front());
llvm::SmallVector<Value, 8> arguments;
for (int i = 0; i < instr->operands().size(); i++) {
for (int i = 0; i < instr->operands().size(); ++i) {
const HloInstruction* operand = instr->operand(i);
xla::ShapeIndex shape_index;
TF_ASSIGN_OR_RETURN(
@ -1077,7 +1079,7 @@ Status LhloDialectEmitter::GetOrCreateViewImpl(
const HloInstruction* instr, const Shape& current_shape,
::xla::ShapeIndex* current_shape_index, SmallVectorImpl<Value>* values) {
if (current_shape.IsTuple()) {
for (int i = 0; i < current_shape.tuple_shapes().size(); i++) {
for (int i = 0; i < current_shape.tuple_shapes().size(); ++i) {
current_shape_index->push_back(i);
TF_RETURN_IF_ERROR(GetOrCreateViewImpl(
instr, current_shape.tuple_shapes(i), current_shape_index, values));

View File

@ -91,16 +91,15 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::HloInstruction* instr);
// Create LHLO operation operands given an XLA HLO instruction. By default,
// all XLA HLO operands and results are converted to MLIR. If `num_operands`
// is specified, only the first `num_operand` operands of the instruction are
// converted to MLIR. The function returns the actual number of operands and
// results generated for MLIR in `num_arguments` and `num_results`.
// TODO(jurahul): Move all function outputs (operands, num_arguments, and
// num_results) to the end of the argument list.
::xla::Status CreateOperands(
::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands = absl::nullopt);
// all XLA HLO operands and results are converted to MLIR and appended to
// `operands`. If `num_operands` is specified, only the first `num_operand`
// operands of the instruction are converted to MLIR. The function returns the
// actual number of operands and results generated for MLIR in `num_arguments`
// and `num_results`.
::xla::Status CreateOperands(::xla::HloInstruction* instr,
absl::optional<xla::int64> num_operands,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);
template <typename OpType>
::xla::StatusOr<OpType> CreateOpWithoutAttrs(