[NFC] Migrate more code to use ranges instead of ArrayRef

PiperOrigin-RevId: 333521546
Change-Id: I121d5c4e7e3a0383bdc4c0cb531df38da89a7d85
This commit is contained in:
Rahul Joshi 2020-09-24 08:42:34 -07:00 committed by TensorFlower Gardener
parent 020f426c0e
commit e930518aa5
8 changed files with 34 additions and 63 deletions

View File

@ -504,13 +504,12 @@ LogicalResult Verify(ReplicateOp op) {
return success();
}
template <typename OperandsTy, typename ResultsTy>
void BuildReplicateOp(
Builder* builder, OperationState* state, int n,
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<OperandsTy, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs, ResultsTy replica_output_types) {
llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
ValueRange packed_inputs, TypeRange replica_output_types) {
DCHECK_GE(n, 2);
state->addAttribute("n", builder->getI32IntegerAttr(n));
@ -538,7 +537,7 @@ void BuildReplicateOp(
block.addArgument(replicated_input.second);
}
for (auto& packed_input : packed_inputs) {
for (auto packed_input : packed_inputs) {
state->addOperands(packed_input);
block.addArgument(packed_input.getType());
}
@ -560,20 +559,8 @@ void ReplicateOp::build(
OpBuilder& builder, OperationState& state, int n,
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
llvm::ArrayRef<Type> replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
packed_inputs, replica_output_types);
}
void ReplicateOp::build(
OpBuilder& builder, OperationState& state, int n,
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
Operation::result_type_range replica_output_types) {
llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs,
ValueRange packed_inputs, TypeRange replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
packed_inputs, replica_output_types);
}

View File

@ -295,14 +295,8 @@ For example:
let builders = [
OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
"llvm::ArrayRef<Value> packed_inputs, "
"llvm::ArrayRef<Type> replica_output_types">,
OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
"llvm::ArrayRef<Value> packed_inputs, "
"Operation::result_type_range replica_output_types">
"llvm::ArrayRef<std::pair<ValueRange, Type>> replicated_inputs, "
"ValueRange packed_inputs, TypeRange replica_output_types">,
];
let parser = [{ return Parse$cppClass(&parser, &result); }];

View File

@ -67,7 +67,7 @@ using tensorflow::shape_inference::ShapeHandle;
namespace mlir {
namespace TF {
namespace {
Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
Optional<TypeRange> InferShapeForFunctionReturnType(FuncOp func) {
// Find any return ops.
SmallVector<ReturnOp, 4> return_ops;
for (Block& block : func) {
@ -110,7 +110,7 @@ Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
}
}
return llvm::to_vector<4>(return_op.getOperandTypes());
return TypeRange(return_op.getOperandTypes());
}
// Returns if the shape inference pass supports an op outside the TF dialect.
@ -805,7 +805,6 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
ModuleOp module, Operation::operand_type_range input_types,
ArrayRef<FuncOp> functions, int64_t max_iteration) {
bool all_succeeded = true;
auto types = llvm::to_vector<4>(input_types);
// If shape propagation fails for one function, return failure, but do not
// early exit and attempt to propagate shapes for all provided functions to
// have a best-effort propagation.
@ -822,8 +821,8 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
}
FunctionType func_type = func.getType();
func.setType(
FunctionType::get(types, func_type.getResults(), func.getContext()));
func.setType(FunctionType::get(input_types, func_type.getResults(),
func.getContext()));
auto res =
PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration);
@ -834,7 +833,7 @@ LogicalResult ShapeInference::PropagateShapeToFunctions(
auto new_return_types = InferShapeForFunctionReturnType(func);
if (new_return_types)
func.setType(FunctionType::get(types, new_return_types.getValue(),
func.setType(FunctionType::get(input_types, new_return_types.getValue(),
func.getContext()));
}
return success(all_succeeded);
@ -844,16 +843,17 @@ LogicalResult ShapeInference::PropagateShapeToRegions(
Operation::operand_type_range input_types, ArrayRef<Region*> regions,
int64_t max_iteration) {
bool all_succeeded = true;
auto types = llvm::to_vector<4>(input_types);
// If shape propagation fails for one region, return failure, but do not
// early exit and attempt to propagate shapes for all provided regions to
// have a best-effort propagation.
for (auto region : regions) {
// Refine region arguments.
Block& entry = region->front();
assert(types.size() == entry.getNumArguments());
for (auto arg_and_idx : llvm::enumerate(entry.getArguments())) {
arg_and_idx.value().setType(types[arg_and_idx.index()]);
assert(llvm::size(input_types) == entry.getNumArguments());
for (auto it : llvm::zip(entry.getArguments(), input_types)) {
BlockArgument arg = std::get<0>(it);
Type type = std::get<1>(it);
arg.setType(type);
}
// Propagate shapes into the region.

View File

@ -73,8 +73,7 @@ void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
func.setType(FunctionType::get(
arg_types,
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
arg_types, func.front().getTerminator()->getOperandTypes(),
func.getContext()));
}

View File

@ -377,8 +377,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
// Check if number of operands of each used TPUReplicatedInput op matches
// `num_replicas` or 1. Collect all their operands and associated type for
// creating the replicate op.
llvm::SmallVector<std::pair<Operation::operand_range, Type>, 8>
replicated_inputs;
llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
llvm::SmallVector<Value, 8> packed_inputs;
for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
auto input = pos_and_input.value();

View File

@ -283,8 +283,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
->getSecond()
.size() == num_replicas);
llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
new_replicated_inputs;
llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
llvm::SmallVector<Value, 8> new_packed_inputs;
llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
@ -310,8 +309,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp(
auto new_replicate = builder.create<tf_device::ReplicateOp>(
replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
new_packed_inputs,
llvm::to_vector<8>(
replicate.GetBody().getTerminator()->getOperandTypes()));
replicate.GetBody().getTerminator()->getOperandTypes());
for (auto arg : replicate.GetBody().getArguments()) {
if (replicate.IsReplicatedBlockArgument(arg)) {
arg.replaceAllUsesWith(
@ -464,8 +462,7 @@ void HandleReplicateOp(TF::WhileRegionOp while_op,
// Build the replicated unformat op after the loop. First prepare building the
// replicate op.
llvm::SmallVector<std::pair<llvm::ArrayRef<Value>, Type>, 8>
unformat_replicate_operands;
llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
llvm::SmallVector<Value, 8> unformat_packed_operands;
for (const auto& entry : execute_arg_to_outer_args) {
if (entry.second.size() > 1) {
@ -495,7 +492,7 @@ void HandleReplicateOp(TF::WhileRegionOp while_op,
// With all replicated inputs, now build the replicate op.
auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
unformat_packed_operands, ArrayRef<Type>{});
unformat_packed_operands, TypeRange{});
// Then build the unformat op in the replicate op.
builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
llvm::SmallVector<Value, 8> unformat_operands;

View File

@ -130,18 +130,15 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) {
OpBuilder builder(&island.GetBody(), island.GetBody().begin());
tf_executor::YieldOp yield = island.GetYield();
if (yield.getNumOperands() == 0) {
builder.create<TF::NoOp>(island.getLoc(), llvm::ArrayRef<mlir::Type>{},
llvm::ArrayRef<mlir::Value>{},
llvm::ArrayRef<mlir::NamedAttribute>{});
builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
} else if (yield.getNumOperands() == 1) {
Value operand = yield.getOperand(0);
auto identity = builder.create<TF::IdentityOp>(island.getLoc(),
operand.getType(), operand);
yield.setOperand(0, identity.output());
} else {
auto types = llvm::to_vector<4>(yield.getOperandTypes());
auto identity_n = builder.create<TF::IdentityNOp>(island.getLoc(), types,
yield.getOperands());
auto identity_n = builder.create<TF::IdentityNOp>(
island.getLoc(), yield.getOperandTypes(), yield.getOperands());
for (auto it : llvm::enumerate(identity_n.getResults()))
yield.setOperand(it.index(), it.value());
}
@ -149,8 +146,8 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) {
// Helper that creates an island. If `sub_op` is not nullptr, it will be moved
// to the island. Otherwise a NoOp will be added to the island.
tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
ArrayRef<Value> control_inputs,
tf_executor::IslandOp CreateIsland(TypeRange result_types,
ValueRange control_inputs,
const tf_executor::ControlType& control_type,
const Location& loc, Operation* sub_op,
tf_executor::IslandOp original_island) {
@ -166,10 +163,8 @@ tf_executor::IslandOp CreateIsland(ArrayRef<Type> result_types,
sub_op->moveBefore(block, block->begin());
island_builder.create<tf_executor::YieldOp>(loc, sub_op->getResults());
} else {
island_builder.create<TF::NoOp>(
island.getLoc(), llvm::ArrayRef<mlir::Type>{},
llvm::ArrayRef<mlir::Value>{}, llvm::ArrayRef<mlir::NamedAttribute>{});
island_builder.create<tf_executor::YieldOp>(loc, ArrayRef<Value>{});
island_builder.create<TF::NoOp>(island.getLoc(), TypeRange{}, ValueRange{});
island_builder.create<tf_executor::YieldOp>(loc, ValueRange{});
}
return island;
}
@ -282,8 +277,8 @@ void BreakUpIslands::BreakUpIsland(
? island_control_inputs
: predecessor_controls;
auto new_island =
CreateIsland(llvm::to_vector<4>(sub_op.getResultTypes()), control,
control_type, sub_op.getLoc(), &sub_op, island_op);
CreateIsland(sub_op.getResultTypes(), control, control_type,
sub_op.getLoc(), &sub_op, island_op);
new_control_for_sub_ops[&sub_op] = new_island.control();
if (sources_and_sinks.sinks.count(&sub_op)) {
sink_island_controls.push_back(new_island.control());

View File

@ -760,8 +760,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
devices;
auto replicate = builder.create<mlir::tf_device::ReplicateOp>(
mlir::UnknownLoc::get(&context), /*num_replicas=*/2, devices,
llvm::ArrayRef<std::pair<llvm::ArrayRef<mlir::Value>, mlir::Type>>{},
llvm::ArrayRef<mlir::Value>{}, llvm::ArrayRef<mlir::Type>{});
llvm::ArrayRef<std::pair<mlir::ValueRange, mlir::Type>>{},
mlir::ValueRange{}, mlir::TypeRange{});
builder.setInsertionPoint(&replicate.body().front(),
replicate.body().front().begin());