[NFC] Migrate more code to use ranges instead of ArrayRef

PiperOrigin-RevId: 333521546
Change-Id: I121d5c4e7e3a0383bdc4c0cb531df38da89a7d85

commit e930518aa5 (parent 020f426c0e)
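Background for the change (reader's note, not part of the commit): mlir::ValueRange and mlir::TypeRange are cheap, non-owning views over a sequence of values or types. Because they are implicitly constructible from ArrayRef as well as from an operation's operand/result ranges, a single range-typed parameter can replace several overloads and avoid llvm::to_vector copies at call sites. A minimal sketch with a hypothetical helper:

// Hypothetical sketch (C++/MLIR): one TypeRange parameter serves all callers.
void PrintTypes(mlir::TypeRange types) {
  for (mlir::Type type : types) type.dump();
}

void Callers(mlir::Operation* op, llvm::ArrayRef<mlir::Type> stored) {
  PrintTypes(op->getResultTypes());  // result_type_range -> TypeRange, no copy
  PrintTypes(stored);                // ArrayRef<Type>    -> TypeRange, no copy
}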
@@ -504,13 +504,12 @@ LogicalResult Verify(ReplicateOp op) {
   return success();
 }
 
-template <typename OperandsTy, typename ResultsTy>
 void BuildReplicateOp(
     Builder* builder, OperationState* state, int n,
     const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
         devices,
-    llvm::ArrayRef<std::pair<OperandsTy, Type>> replicated_inputs,
-    llvm::ArrayRef<Value> packed_inputs, ResultsTy replica_output_types) {
+    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
+    ValueRange packed_inputs, TypeRange replica_output_types) {
   DCHECK_GE(n, 2);
   state->addAttribute("n", builder->getI32IntegerAttr(n));
 
@@ -538,7 +537,7 @@ void BuildReplicateOp(
     block.addArgument(replicated_input.second);
   }
 
-  for (auto& packed_input : packed_inputs) {
+  for (auto packed_input : packed_inputs) {
    state->addOperands(packed_input);
    block.addArgument(packed_input.getType());
  }
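Note on the `auto&` -> `auto` change above: iterating a ValueRange yields `Value` objects by value (Value is a pointer-sized handle), so a non-const lvalue reference cannot bind to the elements. Taking the loop variable by value is both required and cheap; a hypothetical sketch:

// Hypothetical sketch: Value is a copyable handle, so iterate by value.
void AddAll(mlir::ValueRange inputs, mlir::OperationState* state) {
  for (mlir::Value input : inputs) state->addOperands(input);
}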
@@ -560,20 +559,8 @@ void ReplicateOp::build(
     OpBuilder& builder, OperationState& state, int n,
     const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
         devices,
-    llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs,
-    llvm::ArrayRef<Value> packed_inputs,
-    llvm::ArrayRef<Type> replica_output_types) {
-  BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
-                   packed_inputs, replica_output_types);
-}
-
-void ReplicateOp::build(
-    OpBuilder& builder, OperationState& state, int n,
-    const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
-        devices,
-    llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs,
-    llvm::ArrayRef<Value> packed_inputs,
-    Operation::result_type_range replica_output_types) {
+    llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
+    ValueRange packed_inputs, TypeRange replica_output_types) {
   BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
                    packed_inputs, replica_output_types);
 }
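The two build overloads above collapse into one because ValueRange is implicitly constructible from both llvm::ArrayRef<Value> and Operation::operand_range, and TypeRange likewise from llvm::ArrayRef<Type> and Operation::result_type_range, so every old call site keeps compiling against the single signature. A hypothetical sketch of the conversions:

// Hypothetical sketch: the implicit conversions that make one overload enough.
void Conversions(mlir::Operation* op, llvm::SmallVector<mlir::Value, 4>& stored) {
  mlir::ValueRange from_vector(stored);               // via ArrayRef<Value>
  mlir::ValueRange from_operands(op->getOperands());  // via operand_range
  mlir::TypeRange from_results(op->getResultTypes()); // via result_type_range
  (void)from_vector; (void)from_operands; (void)from_results;
}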
@@ -295,14 +295,8 @@ For example:
   let builders = [
     OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
               "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
-              "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
-              "llvm::ArrayRef<Value> packed_inputs, "
-              "llvm::ArrayRef<Type> replica_output_types">,
-    OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
-              "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
-              "llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
-              "llvm::ArrayRef<Value> packed_inputs, "
-              "Operation::result_type_range replica_output_types">
+              "llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs, "
+              "ValueRange packed_inputs, TypeRange replica_output_types">,
   ];
 
   let parser = [{ return Parse$cppClass(&parser, &result); }];
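Note: this ODS `builders` entry only declares the build method; its definition is the range-based ReplicateOp::build updated in the C++ hunks above, so the signature string here has to stay in sync with that definition.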
@@ -67,7 +67,7 @@ using tensorflow::shape_inference::ShapeHandle;
 namespace mlir {
 namespace TF {
 namespace {
-Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
+Optional<TypeRange> InferShapeForFunctionReturnType(FuncOp func) {
   // Find any return ops.
   SmallVector<ReturnOp, 4> return_ops;
   for (Block& block : func) {
@@ -110,7 +110,7 @@ Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
     }
   }
 
-  return llvm::to_vector<4>(return_op.getOperandTypes());
+  return TypeRange(return_op.getOperandTypes());
 }
 
 // Returns if the shape inference pass supports an op outside the TF dialect.
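A caveat worth spelling out (not stated in the commit): TypeRange is a view, not an owning container, so the returned range is only valid while `return_op` lives. That holds here because the function stays in the module while callers consume the result. When the types must outlive the op, copy them out; a hypothetical sketch:

// Hypothetical sketch: materialize an owning copy of a type view.
llvm::SmallVector<mlir::Type, 4> CopyTypes(mlir::TypeRange view) {
  return llvm::to_vector<4>(view);
}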
@@ -805,7 +805,6 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
     ModuleOp module, Operation::operand_type_range input_types,
     ArrayRef<FuncOp> functions, int64_t max_iteration) {
   bool all_succeeded = true;
-  auto types = llvm::to_vector<4>(input_types);
   // If shape propagation fails for one function, return failure, but do not
   // early exit and attempt to propagate shapes for all provided functions to
   // have a best-effort propagation.
@@ -822,8 +821,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
     }
 
     FunctionType func_type = func.getType();
-    func.setType(
-        FunctionType::get(types, func_type.getResults(), func.getContext()));
+    func.setType(FunctionType::get(input_types, func_type.getResults(),
+                                   func.getContext()));
 
     auto res =
         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
@@ -834,7 +833,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
 
     auto new_return_types = InferShapeForFunctionReturnType(func);
     if (new_return_types)
-      func.setType(FunctionType::get(types, new_return_types.getValue(),
+      func.setType(FunctionType::get(input_types, new_return_types.getValue(),
                                      func.getContext()));
   }
   return success(all_succeeded);
@@ -844,16 +843,17 @@ LogicalResult ShapeInference::PropagateShapeToRegions(
     Operation::operand_type_range input_types, ArrayRef<Region*> regions,
     int64_t max_iteration) {
   bool all_succeeded = true;
-  auto types = llvm::to_vector<4>(input_types);
   // If shape propagation fails for one region, return failure, but do not
   // early exit and attempt to propagate shapes for all provided regions to
   // have a best-effort propagation.
   for (auto region : regions) {
     // Refine region arguments.
     Block& entry = region->front();
-    assert(types.size() == entry.getNumArguments());
-    for (auto arg_and_idx : llvm::enumerate(entry.getArguments())) {
-      arg_and_idx.value().setType(types[arg_and_idx.index()]);
+    assert(llvm::size(input_types) == entry.getNumArguments());
+    for (auto it : llvm::zip(entry.getArguments(), input_types)) {
+      BlockArgument arg = std::get<0>(it);
+      Type type = std::get<1>(it);
+      arg.setType(type);
     }
 
     // Propagate shapes into the region.
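Ranges generally offer neither operator[] nor a size() member, hence the switch above from indexed access to llvm::zip (which walks both ranges in lockstep) and llvm::size (which computes the length of any range) — both from llvm/ADT/STLExtras.h. The same pattern in isolation, as a sketch mirroring the change:

// Sketch: retype block arguments from a parallel range of types.
void SetArgTypes(mlir::Block& block, mlir::TypeRange types) {
  assert(llvm::size(types) == block.getNumArguments());
  for (auto it : llvm::zip(block.getArguments(), types))
    std::get<0>(it).setType(std::get<1>(it));
}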
@@ -73,8 +73,7 @@ void UpdateFuncType(FuncOp func) {
   llvm::SmallVector<Type, 8> arg_types;
   for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
   func.setType(FunctionType::get(
-      arg_types,
-      llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
+      arg_types, func.front().getTerminator()->getOperandTypes(),
       func.getContext()));
 }
 
@@ -377,8 +377,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
   // Check if number of operands of each used TPUReplicatedInput op matches
   // `num_replicas` or 1. Collect all their operands and associated type for
   // creating the replicate op.
-  llvm::SmallVector<std::pair<Operation::operand_range, Type>, 8>
-      replicated_inputs;
+  llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
   llvm::SmallVector<Value, 8> packed_inputs;
   for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
     auto input = pos_and_input.value();
@@ -283,8 +283,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
               ->getSecond()
               .size() == num_replicas);
 
-  llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
-      new_replicated_inputs;
+  llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
   llvm::SmallVector<Value, 8> new_packed_inputs;
   llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
   replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
@@ -310,8 +309,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
   auto new_replicate = builder.create<tf_device::ReplicateOp>(
       replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
       new_packed_inputs,
-      llvm::to_vector<8>(
-          replicate.GetBody().getTerminator()->getOperandTypes()));
+      replicate.GetBody().getTerminator()->getOperandTypes());
   for (auto arg : replicate.GetBody().getArguments()) {
     if (replicate.IsReplicatedBlockArgument(arg)) {
       arg.replaceAllUsesWith(
@@ -464,8 +462,7 @@ void HandleReplicateOp(TF::WhileRegionOp while_op,
 
   // Build the replicated unformat op after the loop. First prepare building the
   // replicate op.
-  llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
-      unformat_replicate_operands;
+  llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
   llvm::SmallVector<Value, 8> unformat_packed_operands;
   for (const auto& entry : execute_arg_to_outer_args) {
     if (entry.second.size() > 1) {
@@ -495,7 +492,7 @@ void HandleReplicateOp(TF::WhileRegionOp while_op,
   // With all replicated inputs, now build the replicate op.
   auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
       while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
-      unformat_packed_operands, ArrayRef<Type>{});
+      unformat_packed_operands, TypeRange{});
   // Then build the unformat op in the replicate op.
   builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
   llvm::SmallVector<Value, 8> unformat_operands;
@@ -130,18 +130,15 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) {
   OpBuilder builder(&island.GetBody(), island.GetBody().begin());
   tf_executor::YieldOp yield = island.GetYield();
   if (yield.getNumOperands() == 0) {
-    builder.create<TF::NoOp>(island.getLoc(), llvm::ArrayRef<mlir::Type>{},
-                             llvm::ArrayRef<mlir::Value>{},
-                             llvm::ArrayRef<mlir::NamedAttribute>{});
+    builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
   } else if (yield.getNumOperands() == 1) {
     Value operand = yield.getOperand(0);
     auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
                                                    operand.getType(), operand);
     yield.setOperand(0, identity.output());
   } else {
-    auto types = llvm::to_vector<4>(yield.getOperandTypes());
-    auto identity_n = builder.create<TF::IdentityNOp>(island.getLoc(), types,
-                                                      yield.getOperands());
+    auto identity_n = builder.create<TF::IdentityNOp>(
+        island.getLoc(), yield.getOperandTypes(), yield.getOperands());
     for (auto it : llvm::enumerate(identity_n.getResults()))
       yield.setOperand(it.index(), it.value());
   }
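The empty ranges are written out as `TypeRange{}` and `ValueRange{}` rather than bare `{}`: the generated build methods are overloaded, so the explicit types are presumably what disambiguate the call (an assumption about the generated code, but consistent with every call site in this commit). The trailing empty NamedAttribute list can now be left to its default. A hypothetical sketch:

// Hypothetical sketch: a zero-result, zero-operand op via explicit empty ranges.
void CreateNoOp(mlir::OpBuilder& builder, mlir::Location loc) {
  builder.create<TF::NoOp>(loc, mlir::TypeRange{}, mlir::ValueRange{});
}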
@@ -149,8 +146,8 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) {
 
 // Helper that creates an island. If `sub_op` is not nullptr, it will be moved
 // to the island. Otherwise a NoOp will be added to the island.
-tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
-                                   ArrayRef<Value> control_inputs,
+tf_executor::IslandOp CreateIsland(TypeRange result_types,
+                                   ValueRange control_inputs,
                                    const tf_executor::ControlType& control_type,
                                    const Location& loc, Operation* sub_op,
                                    tf_executor::IslandOp original_island) {
@@ -166,10 +163,8 @@ tf_executor::IslandOp CreateIsland(TypeRange result_types,
     sub_op->moveBefore(block, block->begin());
     island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
   } else {
-    island_builder.create<TF::NoOp>(
-        island.getLoc(), llvm::ArrayRef<mlir::Type>{},
-        llvm::ArrayRef<mlir::Value>{}, llvm::ArrayRef<mlir::NamedAttribute>{});
-    island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value>{});
+    island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
+    island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
   }
   return island;
 }
@@ -282,8 +277,8 @@ void BreakUpIslands::BreakUpIsland(
             ? island_control_inputs
             : predecessor_controls;
     auto new_island =
-        CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control,
-                     control_type, sub_op.getLoc(), &sub_op, island_op);
+        CreateIsland(sub_op.getResultTypes(), control, control_type,
+                     sub_op.getLoc(), &sub_op, island_op);
     new_control_for_sub_ops[&sub_op] = new_island.control();
     if (sources_and_sinks.sinks.count(&sub_op)) {
       sink_island_controls.push_back(new_island.control());
@@ -760,8 +760,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
       devices;
   auto replicate = builder.create<mlir::tf_device::ReplicateOp>(
       mlir::UnknownLoc::get(&context), /*num_replicas=*/2, devices,
-      llvm::ArrayRef<std::pair<llvm::ArrayRef<mlir::Value>, mlir::Type>>{},
-      llvm::ArrayRef<mlir::Value>{}, llvm::ArrayRef<mlir::Type>{});
+      llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
+      mlir::ValueRange{}, mlir::TypeRange{});
   builder.setInsertionPoint(&replicate.body().front(),
                             replicate.body().front().begin());