[NFC] Eliminate use of .getBlocks() when not needed

Also use llvm::hasSingleElement() instead of .size() == 1

PiperOrigin-RevId: 317675565
Change-Id: I4f0e8892957c2b20e115584fe7424da68d53b67a
This commit is contained in:
Rahul Joshi 2020-06-22 09:58:59 -07:00 committed by TensorFlower Gardener
parent 2afae99d7f
commit 251923169d
18 changed files with 41 additions and 40 deletions

View File

@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
}
for (auto fn : module.getOps<FuncOp>()) {
if (fn.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(fn)) {
return fn.emitError("should have exactly one basic block"), false;
}
auto& bb = fn.getBlocks().front();
auto& bb = fn.front();
for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn))
@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
auto term = fn.getBlocks().back().getTerminator();
auto term = fn.back().getTerminator();
if (output_names.size() != term->getNumOperands()) {
fn.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands (" << term->getNumOperands()

View File

@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody());
auto& bb = func.getBlocks().front();
auto& bb = func.front();
auto* terminator = bb.getTerminator();
int num_args = bb.getNumArguments();

View File

@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
parser->parseRegion(body, region_args, region_arg_types))
return failure();
if (body.getBlocks().size() > 1)
return parser->emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with
// a ReturnOp terminator.
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
if (!llvm::hasSingleElement(body))
return parser->emitError(loc) << "expects a single block region";
Operation& terminator = body.front().back();
if (!isa<ReturnOp>(terminator))
return parser->emitError(loc) << "expects a tf_device.return terminator";

View File

@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
Region &body = *result.addRegion();
if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
if (body.getBlocks().size() > 1)
return parser.emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with
// a FetchOp terminator.
GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
if (!llvm::hasSingleElement(body))
return parser.emitError(loc) << "expects a single block region";
// Get the results type from the terminator type inside the graph.
Operation &fetch = body.back().back();
if (!isa<FetchOp>(fetch))

View File

@ -199,7 +199,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
// Folds merge nodes with only a single non-dead input.
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
// Create builder for val_index of MergeOp.
auto* block = &function.getBlocks().front();
auto* block = &function.front();
OpBuilder builder = OpBuilder::atBlockEnd(block);
auto type = builder.getIntegerType(32);
auto build_index = [&](Location loc, int value) {

View File

@ -71,7 +71,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
return;
}
Region &body = main.getBody();
if (body.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(body)) {
op->emitError() << "MLIR Opaque Op expects a main() entry point with a "
"single block\n";
return;

View File

@ -80,11 +80,11 @@ constexpr char kResourceNameArgAttr[] = "tf.resource_name";
// Checks if a function has only one block.
mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
if (!hasSingleElement(function.getBlocks()))
if (!llvm::hasSingleElement(function)) {
return function.emitError()
<< "expects function '" << function.getName()
<< "' to have 1 block, got " << function.getBlocks().size();
}
return success();
}

View File

@ -1113,7 +1113,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
// This routine should only be called when control flow operations are still
// represented with TF IfOp and WhileOp operations. In this case, there should
// be only one basic blocks in the MLIR representation.
if (!hasSingleElement(function.getBlocks())) {
if (!llvm::hasSingleElement(function)) {
return function.emitError()
<< "expect the function to have 1 block while it has "
<< function.getBlocks().size();

View File

@ -159,8 +159,7 @@ llvm::SmallVector<FunctionAndArgumentInfo, 4> ExtractFunctionsConnectedToArg(
while (!functions_to_parse.empty()) {
llvm::SmallVector<FunctionAndArgumentInfo, 4> newly_discovered_functions;
for (auto function_info : functions_to_parse) {
Block& func_entry_block =
function_info.func.getBody().getBlocks().front();
Block& func_entry_block = function_info.func.front();
auto argument =
func_entry_block.getArgument(function_info.argument_index);
@ -186,8 +185,7 @@ void IdentifyXlaShardingForComputationInputs(
StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
FuncOp cluster_function, Builder* builder) {
// Look up function definition from module.
Block& cluster_function_block =
cluster_function.getBody().getBlocks().front();
Block& cluster_function_block = cluster_function.front();
ModuleOp module = cluster_func_op.getParentOfType<ModuleOp>();
llvm::SmallVector<llvm::StringRef, 8> sharding_for_args(
@ -215,8 +213,7 @@ void IdentifyXlaShardingForComputationInputs(
const int function_argument_index = function_arg_info.argument_index;
auto& parsed_function = function_arg_info.func;
Block& parsed_function_block =
parsed_function.getBody().getBlocks().front();
Block& parsed_function_block = parsed_function.front();
arg_sharding = ParseInputSharding(
parsed_function_block.getArgument(function_argument_index));
}
@ -245,7 +242,7 @@ void IdentifyXlaShardingForComputationOutputs(
tf_device::ClusterFuncOp cluster_func, Builder* builder) {
// By default return values from logical core 0 is used if no sharding
// configuration is defined.
Block& function_block = func.getBody().getBlocks().front();
Block& function_block = func.front();
Operation* terminator = function_block.getTerminator();
llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets(
terminator->getNumOperands(), logical_core_0_sharding);

View File

@ -128,7 +128,7 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
Status status = Status::OK();
module.walk([&](mlir::FuncOp function) {
if (function.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(function)) {
status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg,
"only single block functions are supported.");

View File

@ -46,13 +46,13 @@ struct FunctionalToExecutorDialectConversion
} // end anonymous namespace
void FunctionalToExecutorDialectConversion::runOnFunction() {
if (getFunction().getBlocks().size() != 1) {
if (!llvm::hasSingleElement(getFunction())) {
LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion "
"to tf_executor dialect\n");
return;
}
auto loc = getFunction().getLoc();
mlir::Block& body = getFunction().getBody().front();
mlir::Block& body = getFunction().front();
// Find region of interest and ReturnOp.
auto copy_range = body.without_terminator();
if (copy_range.begin() != copy_range.end() &&

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@ -26,12 +27,12 @@ static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main");
if (!fn) return nullptr;
if (fn.getBlocks().size() != 1) return nullptr;
if (!llvm::hasSingleElement(fn)) return nullptr;
// Here, modules with exactly two operations in the only basic block are
// supported. The last operation should be a terminator operation and the
// other operation is the operation of interest.
auto& block = fn.getBlocks().front();
auto& block = fn.front();
if (block.getOperations().size() != 2) return nullptr;
if (!block.back().isKnownTerminator()) return nullptr;

View File

@ -1148,13 +1148,13 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
if (lowered_computation_.count(f)) return success();
if (f.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(f)) {
return f.emitError("only single block Function supported");
}
// Create a sub-builder if this is not the main function.
std::unique_ptr<xla::XlaBuilder> builder_up;
bool entry_function = f.getName().str() == "main";
bool entry_function = f.getName() == "main";
if (!entry_function)
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up;

View File

@ -230,10 +230,10 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return failure();
if (op.getParentRegion()->getBlocks().size() != 1) {
op.emitOpError() << "tensor to buffer conversion expects a single block "
if (!llvm::hasSingleElement(op.body())) {
return op.emitOpError()
<< "tensor to buffer conversion expects a single block "
"in the region containing the operation";
return failure();
}
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
@ -320,13 +321,14 @@ LogicalResult FuncLegalizer::PrepareParams() {
}
LogicalResult FuncLegalizer::Legalize() {
if (func_.empty()) return success();
// TensorFlow functions don't use CFGs.
if (func_.getBlocks().size() > 1) {
if (!llvm::hasSingleElement(func_)) {
emitError(func_.getLoc()) << "requires at most one block in a TF function";
return failure();
}
if (func_.getBlocks().empty()) return success();
Block& block = func_.getBlocks().front();
Block& block = func_.front();
std::vector<Operation*> ops;
ops.reserve(block.getOperations().size());

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@ -44,7 +45,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
auto func = getFunction();
// TODO(pifon): Remove assumption that the function has a single block.
if (func.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(func)) {
emitError(func.getLoc(), "The function needs to have a single block.");
signalPassFailure();
return;
@ -58,7 +59,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
for (auto func_arg : func.getArguments()) {
result_buffers.insert(func_arg);
}
for (auto& block : func.getBlocks()) {
for (auto& block : func) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) {

View File

@ -487,7 +487,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
}
// process each block and do fusion within a block.
for (Block& block : func.getBlocks()) {
for (Block& block : func) {
SmallVector<Operation*, 4> op_list;
for (Operation& op : block) {
op_list.push_back(&op);

View File

@ -301,7 +301,7 @@ struct RewriteKernelSignature
signalPassFailure();
return;
}
if (func.getBlocks().size() != 1) {
if (!llvm::hasSingleElement(func)) {
func.emitError() << "surrounding function has more than one block";
signalPassFailure();
return;