[NFC] Eliminate use of .getBlocks() when not needed
Also use llvm::hasSingleElement() instead of .size() == 1 PiperOrigin-RevId: 317675565 Change-Id: I4f0e8892957c2b20e115584fe7424da68d53b67a
This commit is contained in:
parent
2afae99d7f
commit
251923169d
@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
||||
}
|
||||
|
||||
for (auto fn : module.getOps<FuncOp>()) {
|
||||
if (fn.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(fn)) {
|
||||
return fn.emitError("should have exactly one basic block"), false;
|
||||
}
|
||||
auto& bb = fn.getBlocks().front();
|
||||
auto& bb = fn.front();
|
||||
|
||||
for (auto arg : bb.getArguments()) {
|
||||
if (!HasValidTFLiteType(arg, fn))
|
||||
@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
|
||||
dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
|
||||
str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
|
||||
/*KeepEmpty=*/false);
|
||||
auto term = fn.getBlocks().back().getTerminator();
|
||||
auto term = fn.back().getTerminator();
|
||||
if (output_names.size() != term->getNumOperands()) {
|
||||
fn.emitWarning() << "output names (" << output_names.size()
|
||||
<< ") != terminator operands (" << term->getNumOperands()
|
||||
|
@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper<PostQuantizePass, FunctionPass> {
|
||||
|
||||
void RemoveQuantizationAdaptorOps(FuncOp func) {
|
||||
mlir::OpBuilder builder(func.getBody());
|
||||
auto& bb = func.getBlocks().front();
|
||||
auto& bb = func.front();
|
||||
auto* terminator = bb.getTerminator();
|
||||
|
||||
int num_args = bb.getNumArguments();
|
||||
|
@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
|
||||
parser->parseRegion(body, region_args, region_arg_types))
|
||||
return failure();
|
||||
|
||||
if (body.getBlocks().size() > 1)
|
||||
return parser->emitError(loc) << "expects a single block region";
|
||||
|
||||
// Ensure that the region is well formed: it contains at least a block with
|
||||
// a ReturnOp terminator.
|
||||
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
|
||||
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return parser->emitError(loc) << "expects a single block region";
|
||||
|
||||
Operation& terminator = body.front().back();
|
||||
if (!isa<ReturnOp>(terminator))
|
||||
return parser->emitError(loc) << "expects a tf_device.return terminator";
|
||||
|
@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) {
|
||||
Region &body = *result.addRegion();
|
||||
if (parser.parseRegion(body, llvm::None, llvm::None)) return failure();
|
||||
|
||||
if (body.getBlocks().size() > 1)
|
||||
return parser.emitError(loc) << "expects a single block region";
|
||||
|
||||
// Ensure that the region is well formed: it contains at least a block with
|
||||
// a FetchOp terminator.
|
||||
GraphOp::ensureTerminator(body, parser.getBuilder(), result.location);
|
||||
|
||||
if (!llvm::hasSingleElement(body))
|
||||
return parser.emitError(loc) << "expects a single block region";
|
||||
|
||||
// Get the results type from the terminator type inside the graph.
|
||||
Operation &fetch = body.back().back();
|
||||
if (!isa<FetchOp>(fetch))
|
||||
|
@ -199,7 +199,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
|
||||
// Folds merge nodes with only a single non-dead input.
|
||||
static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
|
||||
// Create builder for val_index of MergeOp.
|
||||
auto* block = &function.getBlocks().front();
|
||||
auto* block = &function.front();
|
||||
OpBuilder builder = OpBuilder::atBlockEnd(block);
|
||||
auto type = builder.getIntegerType(32);
|
||||
auto build_index = [&](Location loc, int value) {
|
||||
|
@ -71,7 +71,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
|
||||
return;
|
||||
}
|
||||
Region &body = main.getBody();
|
||||
if (body.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(body)) {
|
||||
op->emitError() << "MLIR Opaque Op expects a main() entry point with a "
|
||||
"single block\n";
|
||||
return;
|
||||
|
@ -80,11 +80,11 @@ constexpr char kResourceNameArgAttr[] = "tf.resource_name";
|
||||
|
||||
// Checks if a function has only one block.
|
||||
mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
|
||||
if (!hasSingleElement(function.getBlocks()))
|
||||
if (!llvm::hasSingleElement(function)) {
|
||||
return function.emitError()
|
||||
<< "expects function '" << function.getName()
|
||||
<< "' to have 1 block, got " << function.getBlocks().size();
|
||||
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1113,7 +1113,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
|
||||
// This routine should only be called when control flow operations are still
|
||||
// represented with TF IfOp and WhileOp operations. In this case, there should
|
||||
// be only one basic blocks in the MLIR representation.
|
||||
if (!hasSingleElement(function.getBlocks())) {
|
||||
if (!llvm::hasSingleElement(function)) {
|
||||
return function.emitError()
|
||||
<< "expect the function to have 1 block while it has "
|
||||
<< function.getBlocks().size();
|
||||
|
@ -159,8 +159,7 @@ llvm::SmallVector<FunctionAndArgumentInfo, 4> ExtractFunctionsConnectedToArg(
|
||||
while (!functions_to_parse.empty()) {
|
||||
llvm::SmallVector<FunctionAndArgumentInfo, 4> newly_discovered_functions;
|
||||
for (auto function_info : functions_to_parse) {
|
||||
Block& func_entry_block =
|
||||
function_info.func.getBody().getBlocks().front();
|
||||
Block& func_entry_block = function_info.func.front();
|
||||
auto argument =
|
||||
func_entry_block.getArgument(function_info.argument_index);
|
||||
|
||||
@ -186,8 +185,7 @@ void IdentifyXlaShardingForComputationInputs(
|
||||
StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op,
|
||||
FuncOp cluster_function, Builder* builder) {
|
||||
// Look up function definition from module.
|
||||
Block& cluster_function_block =
|
||||
cluster_function.getBody().getBlocks().front();
|
||||
Block& cluster_function_block = cluster_function.front();
|
||||
ModuleOp module = cluster_func_op.getParentOfType<ModuleOp>();
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 8> sharding_for_args(
|
||||
@ -215,8 +213,7 @@ void IdentifyXlaShardingForComputationInputs(
|
||||
|
||||
const int function_argument_index = function_arg_info.argument_index;
|
||||
auto& parsed_function = function_arg_info.func;
|
||||
Block& parsed_function_block =
|
||||
parsed_function.getBody().getBlocks().front();
|
||||
Block& parsed_function_block = parsed_function.front();
|
||||
arg_sharding = ParseInputSharding(
|
||||
parsed_function_block.getArgument(function_argument_index));
|
||||
}
|
||||
@ -245,7 +242,7 @@ void IdentifyXlaShardingForComputationOutputs(
|
||||
tf_device::ClusterFuncOp cluster_func, Builder* builder) {
|
||||
// By default return values from logical core 0 is used if no sharding
|
||||
// configuration is defined.
|
||||
Block& function_block = func.getBody().getBlocks().front();
|
||||
Block& function_block = func.front();
|
||||
Operation* terminator = function_block.getTerminator();
|
||||
llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets(
|
||||
terminator->getNumOperands(), logical_core_0_sharding);
|
||||
|
@ -128,7 +128,7 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper {
|
||||
Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) {
|
||||
Status status = Status::OK();
|
||||
module.walk([&](mlir::FuncOp function) {
|
||||
if (function.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(function)) {
|
||||
status = errors::FailedPrecondition(
|
||||
kInvalidExecutorGraphMsg,
|
||||
"only single block functions are supported.");
|
||||
|
@ -46,13 +46,13 @@ struct FunctionalToExecutorDialectConversion
|
||||
} // end anonymous namespace
|
||||
|
||||
void FunctionalToExecutorDialectConversion::runOnFunction() {
|
||||
if (getFunction().getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(getFunction())) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion "
|
||||
"to tf_executor dialect\n");
|
||||
return;
|
||||
}
|
||||
auto loc = getFunction().getLoc();
|
||||
mlir::Block& body = getFunction().getBody().front();
|
||||
mlir::Block& body = getFunction().front();
|
||||
// Find region of interest and ReturnOp.
|
||||
auto copy_range = body.without_terminator();
|
||||
if (copy_range.begin() != copy_range.end() &&
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
@ -26,12 +27,12 @@ static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) {
|
||||
mlir::FuncOp fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
if (!fn) return nullptr;
|
||||
|
||||
if (fn.getBlocks().size() != 1) return nullptr;
|
||||
if (!llvm::hasSingleElement(fn)) return nullptr;
|
||||
|
||||
// Here, modules with exactly two operations in the only basic block are
|
||||
// supported. The last operation should be a terminator operation and the
|
||||
// other operation is the operation of interest.
|
||||
auto& block = fn.getBlocks().front();
|
||||
auto& block = fn.front();
|
||||
if (block.getOperations().size() != 2) return nullptr;
|
||||
if (!block.back().isKnownTerminator()) return nullptr;
|
||||
|
||||
|
@ -1148,13 +1148,13 @@ LogicalResult ConvertToHloModule::LowerFunctionCall(
|
||||
|
||||
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||
if (lowered_computation_.count(f)) return success();
|
||||
if (f.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(f)) {
|
||||
return f.emitError("only single block Function supported");
|
||||
}
|
||||
|
||||
// Create a sub-builder if this is not the main function.
|
||||
std::unique_ptr<xla::XlaBuilder> builder_up;
|
||||
bool entry_function = f.getName().str() == "main";
|
||||
bool entry_function = f.getName() == "main";
|
||||
if (!entry_function)
|
||||
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
|
||||
auto& builder = entry_function ? module_builder_ : *builder_up;
|
||||
|
@ -230,10 +230,10 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
|
||||
auto loc = op.getLoc();
|
||||
// TODO(b/137624192) Implement variadic reduce.
|
||||
if (op.getNumResults() != 1) return failure();
|
||||
if (op.getParentRegion()->getBlocks().size() != 1) {
|
||||
op.emitOpError() << "tensor to buffer conversion expects a single block "
|
||||
"in the region containing the operation";
|
||||
return failure();
|
||||
if (!llvm::hasSingleElement(op.body())) {
|
||||
return op.emitOpError()
|
||||
<< "tensor to buffer conversion expects a single block "
|
||||
"in the region containing the operation";
|
||||
}
|
||||
const auto& original_results = op.getResults();
|
||||
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
@ -320,13 +321,14 @@ LogicalResult FuncLegalizer::PrepareParams() {
|
||||
}
|
||||
|
||||
LogicalResult FuncLegalizer::Legalize() {
|
||||
if (func_.empty()) return success();
|
||||
|
||||
// TensorFlow functions don't use CFGs.
|
||||
if (func_.getBlocks().size() > 1) {
|
||||
if (!llvm::hasSingleElement(func_)) {
|
||||
emitError(func_.getLoc()) << "requires at most one block in a TF function";
|
||||
return failure();
|
||||
}
|
||||
if (func_.getBlocks().empty()) return success();
|
||||
Block& block = func_.getBlocks().front();
|
||||
Block& block = func_.front();
|
||||
|
||||
std::vector<Operation*> ops;
|
||||
ops.reserve(block.getOperations().size());
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
@ -44,7 +45,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||
auto func = getFunction();
|
||||
|
||||
// TODO(pifon): Remove assumption that the function has a single block.
|
||||
if (func.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(func)) {
|
||||
emitError(func.getLoc(), "The function needs to have a single block.");
|
||||
signalPassFailure();
|
||||
return;
|
||||
@ -58,7 +59,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||
for (auto func_arg : func.getArguments()) {
|
||||
result_buffers.insert(func_arg);
|
||||
}
|
||||
for (auto& block : func.getBlocks()) {
|
||||
for (auto& block : func) {
|
||||
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
|
||||
if (!returnOp) continue;
|
||||
for (auto operand : returnOp.getOperands()) {
|
||||
|
@ -487,7 +487,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
|
||||
}
|
||||
|
||||
// process each block and do fusion within a block.
|
||||
for (Block& block : func.getBlocks()) {
|
||||
for (Block& block : func) {
|
||||
SmallVector<Operation*, 4> op_list;
|
||||
for (Operation& op : block) {
|
||||
op_list.push_back(&op);
|
||||
|
@ -301,7 +301,7 @@ struct RewriteKernelSignature
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (func.getBlocks().size() != 1) {
|
||||
if (!llvm::hasSingleElement(func)) {
|
||||
func.emitError() << "surrounding function has more than one block";
|
||||
signalPassFailure();
|
||||
return;
|
||||
|
Loading…
Reference in New Issue
Block a user