[NFC] Eliminate use of .getBlocks() when not needed

Also use llvm::hasSingleElement() instead of .size() == 1

PiperOrigin-RevId: 317675565
Change-Id: I4f0e8892957c2b20e115584fe7424da68d53b67a
This commit is contained in:
Rahul Joshi 2020-06-22 09:58:59 -07:00 committed by TensorFlower Gardener
parent 2afae99d7f
commit 251923169d
18 changed files with 41 additions and 40 deletions

View File

@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
} }
for (auto fn : module.getOps<FuncOp>()) { for (auto fn : module.getOps<FuncOp>()) {
if (fn.getBlocks().size() != 1) { if (!llvm::hasSingleElement(fn)) {
return fn.emitError("should have exactly one basic block"), false; return fn.emitError("should have exactly one basic block"), false;
} }
auto& bb = fn.getBlocks().front(); auto& bb = fn.front();
for (auto arg : bb.getArguments()) { for (auto arg : bb.getArguments()) {
if (!HasValidTFLiteType(arg, fn)) if (!HasValidTFLiteType(arg, fn))
@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) { dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
str.getValue().split(output_names, ',', /*MaxSplit=*/-1, str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false); /*KeepEmpty=*/false);
auto term = fn.getBlocks().back().getTerminator(); auto term = fn.back().getTerminator();
if (output_names.size() != term->getNumOperands()) { if (output_names.size() != term->getNumOperands()) {
fn.emitWarning() << "output names (" << output_names.size() fn.emitWarning() << "output names (" << output_names.size()
<< ") != terminator operands (" << term->getNumOperands() << ") != terminator operands (" << term->getNumOperands()

View File

@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
void RemoveQuantizationAdaptorOps(FuncOp func) { void RemoveQuantizationAdaptorOps(FuncOp func) {
mlir::OpBuilder builder(func.getBody()); mlir::OpBuilder builder(func.getBody());
auto& bb = func.getBlocks().front(); auto& bb = func.front();
auto* terminator = bb.getTerminator(); auto* terminator = bb.getTerminator();
int num_args = bb.getNumArguments(); int num_args = bb.getNumArguments();

View File

@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
parser->parseRegion(body, region_args, region_arg_types)) parser->parseRegion(body, region_args, region_arg_types))
return failure(); return failure();
if (body.getBlocks().size() > 1)
return parser->emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with // Ensure that the region is well formed: it contains at least a block with
// a ReturnOp terminator. // a ReturnOp terminator.
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location); ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
if (!llvm::hasSingleElement(body))
return parser->emitError(loc) << "expects a single block region";
Operation& terminator = body.front().back(); Operation& terminator = body.front().back();
if (!isa<ReturnOp>(terminator)) if (!isa<ReturnOp>(terminator))
return parser->emitError(loc) << "expects a tf_device.return terminator"; return parser->emitError(loc) << "expects a tf_device.return terminator";

View File

@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
Region &body = *result.addRegion(); Region &body = *result.addRegion();
if (parser.parseRegion(body, llvm::None, llvm::None)) return failure(); if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
if (body.getBlocks().size() > 1)
return parser.emitError(loc) << "expects a single block region";
// Ensure that the region is well formed: it contains at least a block with // Ensure that the region is well formed: it contains at least a block with
// a FetchOp terminator. // a FetchOp terminator.
GraphOp::ensureTerminator(body, parser.getBuilder(), result.location); GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
if (!llvm::hasSingleElement(body))
return parser.emitError(loc) << "expects a single block region";
// Get the results type from the terminator type inside the graph. // Get the results type from the terminator type inside the graph.
Operation &fetch = body.back().back(); Operation &fetch = body.back().back();
if (!isa<FetchOp>(fetch)) if (!isa<FetchOp>(fetch))

View File

@ -199,7 +199,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
// Folds merge nodes with only a single non-dead input. // Folds merge nodes with only a single non-dead input.
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
// Create builder for val_index of MergeOp. // Create builder for val_index of MergeOp.
auto* block = &function.getBlocks().front(); auto* block = &function.front();
OpBuilder builder = OpBuilder::atBlockEnd(block); OpBuilder builder = OpBuilder::atBlockEnd(block);
auto type = builder.getIntegerType(32); auto type = builder.getIntegerType(32);
auto build_index = [&](Location loc, int value) { auto build_index = [&](Location loc, int value) {

View File

@ -71,7 +71,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
return; return;
} }
Region &body = main.getBody(); Region &body = main.getBody();
if (body.getBlocks().size() != 1) { if (!llvm::hasSingleElement(body)) {
op->emitError() << "MLIR Opaque Op expects a main() entry point with a " op->emitError() << "MLIR Opaque Op expects a main() entry point with a "
"single block\n"; "single block\n";
return; return;

View File

@ -80,11 +80,11 @@ constexpr char kResourceNameArgAttr[] = "tf.resource_name";
// Checks if a function has only one block. // Checks if a function has only one block.
mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) { mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
if (!hasSingleElement(function.getBlocks())) if (!llvm::hasSingleElement(function)) {
return function.emitError() return function.emitError()
<< "expects function '" << function.getName() << "expects function '" << function.getName()
<< "' to have 1 block, got " << function.getBlocks().size(); << "' to have 1 block, got " << function.getBlocks().size();
}
return success(); return success();
} }

View File

@ -1113,7 +1113,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
// This routine should only be called when control flow operations are still // This routine should only be called when control flow operations are still
// represented with TF IfOp and WhileOp operations. In this case, there should // represented with TF IfOp and WhileOp operations. In this case, there should
// be only one basic blocks in the MLIR representation. // be only one basic blocks in the MLIR representation.
if (!hasSingleElement(function.getBlocks())) { if (!llvm::hasSingleElement(function)) {
return function.emitError() return function.emitError()
<< "expect the function to have 1 block while it has " << "expect the function to have 1 block while it has "
<< function.getBlocks().size(); << function.getBlocks().size();

View File

@ -159,8 +159,7 @@ llvm::SmallVector<FunctionAndArgumentInfo, 4> ExtractFunctionsConnectedToArg(
while (!functions_to_parse.empty()) { while (!functions_to_parse.empty()) {
llvm::SmallVector<FunctionAndArgumentInfo, 4> newly_discovered_functions; llvm::SmallVector<FunctionAndArgumentInfo, 4> newly_discovered_functions;
for (auto function_info : functions_to_parse) { for (auto function_info : functions_to_parse) {
Block& func_entry_block = Block& func_entry_block = function_info.func.front();
function_info.func.getBody().getBlocks().front();
auto argument = auto argument =
func_entry_block.getArgument(function_info.argument_index); func_entry_block.getArgument(function_info.argument_index);
@ -186,8 +185,7 @@ void IdentifyXlaShardingForComputationInputs(
StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op, StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
FuncOp cluster_function, Builder* builder) { FuncOp cluster_function, Builder* builder) {
// Look up function definition from module. // Look up function definition from module.
Block& cluster_function_block = Block& cluster_function_block = cluster_function.front();
cluster_function.getBody().getBlocks().front();
ModuleOp module = cluster_func_op.getParentOfType<ModuleOp>(); ModuleOp module = cluster_func_op.getParentOfType<ModuleOp>();
llvm::SmallVector<llvm::StringRef, 8> sharding_for_args( llvm::SmallVector<llvm::StringRef, 8> sharding_for_args(
@ -215,8 +213,7 @@ void IdentifyXlaShardingForComputationInputs(
const int function_argument_index = function_arg_info.argument_index; const int function_argument_index = function_arg_info.argument_index;
auto& parsed_function = function_arg_info.func; auto& parsed_function = function_arg_info.func;
Block& parsed_function_block = Block& parsed_function_block = parsed_function.front();
parsed_function.getBody().getBlocks().front();
arg_sharding = ParseInputSharding( arg_sharding = ParseInputSharding(
parsed_function_block.getArgument(function_argument_index)); parsed_function_block.getArgument(function_argument_index));
} }
@ -245,7 +242,7 @@ void IdentifyXlaShardingForComputationOutputs(
tf_device::ClusterFuncOp cluster_func, Builder* builder) { tf_device::ClusterFuncOp cluster_func, Builder* builder) {
// By default return values from logical core 0 is used if no sharding // By default return values from logical core 0 is used if no sharding
// configuration is defined. // configuration is defined.
Block& function_block = func.getBody().getBlocks().front(); Block& function_block = func.front();
Operation* terminator = function_block.getTerminator(); Operation* terminator = function_block.getTerminator();
llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets( llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets(
terminator->getNumOperands(), logical_core_0_sharding); terminator->getNumOperands(), logical_core_0_sharding);

View File

@ -128,7 +128,7 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) { Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
Status status = Status::OK(); Status status = Status::OK();
module.walk([&](mlir::FuncOp function) { module.walk([&](mlir::FuncOp function) {
if (function.getBlocks().size() != 1) { if (!llvm::hasSingleElement(function)) {
status = errors::FailedPrecondition( status = errors::FailedPrecondition(
kInvalidExecutorGraphMsg, kInvalidExecutorGraphMsg,
"only single block functions are supported."); "only single block functions are supported.");

View File

@ -46,13 +46,13 @@ struct FunctionalToExecutorDialectConversion
} // end anonymous namespace } // end anonymous namespace
void FunctionalToExecutorDialectConversion::runOnFunction() { void FunctionalToExecutorDialectConversion::runOnFunction() {
if (getFunction().getBlocks().size() != 1) { if (!llvm::hasSingleElement(getFunction())) {
LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion " LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion "
"to tf_executor dialect\n"); "to tf_executor dialect\n");
return; return;
} }
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
mlir::Block& body = getFunction().getBody().front(); mlir::Block& body = getFunction().front();
// Find region of interest and ReturnOp. // Find region of interest and ReturnOp.
auto copy_range = body.without_terminator(); auto copy_range = body.without_terminator();
if (copy_range.begin() != copy_range.end() && if (copy_range.begin() != copy_range.end() &&

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project
@ -26,12 +27,12 @@ static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main"); mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main");
if (!fn) return nullptr; if (!fn) return nullptr;
if (fn.getBlocks().size() != 1) return nullptr; if (!llvm::hasSingleElement(fn)) return nullptr;
// Here, modules with exactly two operations in the only basic block are // Here, modules with exactly two operations in the only basic block are
// supported. The last operation should be a terminator operation and the // supported. The last operation should be a terminator operation and the
// other operation is the operation of interest. // other operation is the operation of interest.
auto& block = fn.getBlocks().front(); auto& block = fn.front();
if (block.getOperations().size() != 2) return nullptr; if (block.getOperations().size() != 2) return nullptr;
if (!block.back().isKnownTerminator()) return nullptr; if (!block.back().isKnownTerminator()) return nullptr;

View File

@ -1148,13 +1148,13 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
if (lowered_computation_.count(f)) return success(); if (lowered_computation_.count(f)) return success();
if (f.getBlocks().size() != 1) { if (!llvm::hasSingleElement(f)) {
return f.emitError("only single block Function supported"); return f.emitError("only single block Function supported");
} }
// Create a sub-builder if this is not the main function. // Create a sub-builder if this is not the main function.
std::unique_ptr<xla::XlaBuilder> builder_up; std::unique_ptr<xla::XlaBuilder> builder_up;
bool entry_function = f.getName().str() == "main"; bool entry_function = f.getName() == "main";
if (!entry_function) if (!entry_function)
builder_up = module_builder_.CreateSubBuilder(f.getName().str()); builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up; auto& builder = entry_function ? module_builder_ : *builder_up;

View File

@ -230,10 +230,10 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
auto loc = op.getLoc(); auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce. // TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return failure(); if (op.getNumResults() != 1) return failure();
if (op.getParentRegion()->getBlocks().size() != 1) { if (!llvm::hasSingleElement(op.body())) {
op.emitOpError() << "tensor to buffer conversion expects a single block " return op.emitOpError()
<< "tensor to buffer conversion expects a single block "
"in the region containing the operation"; "in the region containing the operation";
return failure();
} }
const auto& original_results = op.getResults(); const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h" #include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project
@ -320,13 +321,14 @@ LogicalResult FuncLegalizer::PrepareParams() {
} }
LogicalResult FuncLegalizer::Legalize() { LogicalResult FuncLegalizer::Legalize() {
if (func_.empty()) return success();
// TensorFlow functions don't use CFGs. // TensorFlow functions don't use CFGs.
if (func_.getBlocks().size() > 1) { if (!llvm::hasSingleElement(func_)) {
emitError(func_.getLoc()) << "requires at most one block in a TF function"; emitError(func_.getLoc()) << "requires at most one block in a TF function";
return failure(); return failure();
} }
if (func_.getBlocks().empty()) return success(); Block& block = func_.front();
Block& block = func_.getBlocks().front();
std::vector<Operation*> ops; std::vector<Operation*> ops;
ops.reserve(block.getOperations().size()); ops.reserve(block.getOperations().size());

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project
@ -44,7 +45,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
auto func = getFunction(); auto func = getFunction();
// TODO(pifon): Remove assumption that the function has a single block. // TODO(pifon): Remove assumption that the function has a single block.
if (func.getBlocks().size() != 1) { if (!llvm::hasSingleElement(func)) {
emitError(func.getLoc(), "The function needs to have a single block."); emitError(func.getLoc(), "The function needs to have a single block.");
signalPassFailure(); signalPassFailure();
return; return;
@ -58,7 +59,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
for (auto func_arg : func.getArguments()) { for (auto func_arg : func.getArguments()) {
result_buffers.insert(func_arg); result_buffers.insert(func_arg);
} }
for (auto& block : func.getBlocks()) { for (auto& block : func) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator()); auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue; if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) { for (auto operand : returnOp.getOperands()) {

View File

@ -487,7 +487,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
} }
// process each block and do fusion within a block. // process each block and do fusion within a block.
for (Block& block : func.getBlocks()) { for (Block& block : func) {
SmallVector<Operation*, 4> op_list; SmallVector<Operation*, 4> op_list;
for (Operation& op : block) { for (Operation& op : block) {
op_list.push_back(&op); op_list.push_back(&op);

View File

@ -301,7 +301,7 @@ struct RewriteKernelSignature
signalPassFailure(); signalPassFailure();
return; return;
} }
if (func.getBlocks().size() != 1) { if (!llvm::hasSingleElement(func)) {
func.emitError() << "surrounding function has more than one block"; func.emitError() << "surrounding function has more than one block";
signalPassFailure(); signalPassFailure();
return; return;