[MLIR][NFC] Adopt variadic isa<>
PiperOrigin-RevId: 318279074 Change-Id: I9845b0278737a4d91b0e1e6699ae008d78e76556
This commit is contained in:
parent
884b55fe86
commit
9a99c02411
@ -190,9 +190,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
|
||||
}
|
||||
|
||||
static bool IsConst(Operation* op) {
|
||||
return isa<mlir::ConstantOp>(op) || isa<mlir::TF::ConstOp>(op) ||
|
||||
isa<tfl::ConstOp>(op) || isa<tfl::QConstOp>(op) ||
|
||||
isa<tfl::SparseConstOp>(op) || isa<tfl::SparseQConstOp>(op);
|
||||
return isa<mlir::ConstantOp, mlir::TF::ConstOp, tfl::ConstOp, tfl::QConstOp,
|
||||
tfl::SparseConstOp, tfl::SparseQConstOp>(op);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -289,8 +289,8 @@ class QuantizationDriver {
|
||||
llvm::errs() << "\n\n\n" << current_op->getName() << "\n";
|
||||
}
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (llvm::isa<quant::QuantizeCastOp>(op) ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<ConstantOp>(op))
|
||||
if (llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp, ConstantOp>(
|
||||
op))
|
||||
return;
|
||||
if (current_op == op) llvm::errs() << "===>>>";
|
||||
llvm::errs() << op->getName() << " : (";
|
||||
|
@ -172,7 +172,7 @@ struct QuantizationPattern : public RewritePattern {
|
||||
Value quantized_value = op->getResult(0);
|
||||
for (Operation* quantized_op : quantized_value.getUsers()) {
|
||||
// If it is requantize op, we shouldn't rewrite this op.
|
||||
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
|
||||
if (llvm::isa<Q, DQ>(quantized_op)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -180,8 +180,8 @@ struct QuantizationPattern : public RewritePattern {
|
||||
// ops dialect, we shouldn't rewrite.
|
||||
if (quantized_op->isKnownTerminator() ||
|
||||
quantized_op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
|
||||
llvm::isa<quant::QuantizeCastOp>(quantized_op) ||
|
||||
llvm::isa<quant::DequantizeCastOp>(quantized_op)) {
|
||||
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(
|
||||
quantized_op)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -49,12 +49,10 @@ using mlir::OwningModuleRef;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
bool IsControlFlowV1Op(Operation* op) {
|
||||
return mlir::isa<mlir::tf_executor::SwitchOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::MergeOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::EnterOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::ExitOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSinkOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSourceOp>(op);
|
||||
return mlir::isa<mlir::tf_executor::SwitchOp, mlir::tf_executor::MergeOp,
|
||||
mlir::tf_executor::EnterOp, mlir::tf_executor::ExitOp,
|
||||
mlir::tf_executor::NextIterationSinkOp,
|
||||
mlir::tf_executor::NextIterationSourceOp>(op);
|
||||
}
|
||||
|
||||
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
|
||||
|
@ -110,8 +110,7 @@ void DefaultQuantParamsPass::runOnFunction() {
|
||||
func.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
|
||||
llvm::isa<quant::QuantizeCastOp>(op) ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op))
|
||||
llvm::isa<quant::QuantizeCastOp, quant::DequantizeCastOp>(op))
|
||||
return;
|
||||
|
||||
for (auto res : op->getResults()) {
|
||||
|
@ -100,8 +100,7 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
|
||||
value = graph.GetFetch().getOperand(res_num);
|
||||
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
|
||||
value = island.GetYield().getOperand(res_num);
|
||||
} else if (llvm::isa<TF::IdentityNOp>(op) ||
|
||||
llvm::isa<TF::IdentityOp>(op)) {
|
||||
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
value = op->getOperand(res_num);
|
||||
} else {
|
||||
return -1;
|
||||
|
@ -48,7 +48,7 @@ struct AnnotateParameterReplication
|
||||
// tf.IdentityOp or a tf.ReadVariableOp.
|
||||
Value SkipIdentityAndReadVariable(Value v) {
|
||||
while (auto op = v.getDefiningOp()) {
|
||||
if (!(isa<TF::IdentityOp>(op) || isa<TF::ReadVariableOp>(op))) break;
|
||||
if (!isa<TF::IdentityOp, TF::ReadVariableOp>(op)) break;
|
||||
v = op->getOperand(0);
|
||||
}
|
||||
return v;
|
||||
|
@ -219,8 +219,7 @@ llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
|
||||
auto type_from_callee = GetElementTypeFromAccess(
|
||||
callee.getArgument(use.getOperandNumber()), module, infer_from_op);
|
||||
if (type_from_callee.hasValue()) return type_from_callee;
|
||||
} else if (llvm::isa<TF::IdentityOp>(use.getOwner()) ||
|
||||
llvm::isa<TF::IdentityNOp>(use.getOwner())) {
|
||||
} else if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(use.getOwner())) {
|
||||
auto type_from_alias = GetElementTypeFromAccess(
|
||||
use.getOwner()->getResult(use.getOperandNumber()), module,
|
||||
infer_from_op);
|
||||
|
@ -49,8 +49,7 @@ LogicalResult ConstantFoldFallbackHook(
|
||||
}
|
||||
|
||||
// Do not execute function calls.
|
||||
if (llvm::isa<TF::WhileOp>(inst) || llvm::isa<TF::IfOp>(inst) ||
|
||||
llvm::isa<CallOpInterface>(inst)) {
|
||||
if (llvm::isa<TF::WhileOp, TF::IfOp, CallOpInterface>(inst)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ struct FusedKernelMatcherPass
|
||||
};
|
||||
|
||||
bool IsActivationFunction(Operation *op) {
|
||||
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
|
||||
return isa<EluOp, ReluOp, Relu6Op>(op);
|
||||
}
|
||||
|
||||
// Finds and returns an activation op that uses the result of `op`. If there are
|
||||
|
@ -96,7 +96,7 @@ class ResourceAnalyzer {
|
||||
}
|
||||
|
||||
func.walk([&](Operation* op) {
|
||||
if (isa<TF::ReadVariableOp>(op) || isa<ReturnOp>(op)) {
|
||||
if (isa<TF::ReadVariableOp, ReturnOp>(op)) {
|
||||
return;
|
||||
}
|
||||
if (auto assign_variable = dyn_cast<TF::AssignVariableOp>(op)) {
|
||||
|
@ -97,8 +97,7 @@ llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
|
||||
// the error message are ordered.
|
||||
llvm::SmallSet<llvm::StringRef, 1> composite_users;
|
||||
for (Operation* user : resource.getUsers())
|
||||
if (!llvm::isa<TF::ReadVariableOp>(user) &&
|
||||
!llvm::isa<TF::AssignVariableOp>(user))
|
||||
if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
|
||||
composite_users.insert(user->getName().getStringRef());
|
||||
|
||||
return composite_users;
|
||||
|
@ -53,8 +53,8 @@ struct ReplicateToIslandPass
|
||||
|
||||
// Returns whether op requires `_xla_replica_id` attribute.
|
||||
bool RequiresReplicaIDAttribute(Operation* op) {
|
||||
return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp>(op) ||
|
||||
llvm::isa<TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
|
||||
return llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
|
||||
TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op);
|
||||
}
|
||||
|
||||
// Adds integer attribute that represents replica id for replicated ops that
|
||||
|
@ -140,7 +140,7 @@ struct ResourceOpLiftingPass
|
||||
// such nodes to carry information.
|
||||
void RemoveIdentity(Block* block) {
|
||||
for (auto& op : llvm::make_early_inc_range(*block)) {
|
||||
if (isa<TF::IdentityOp>(&op) || isa<TF::IdentityNOp>(&op)) {
|
||||
if (isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
}
|
||||
|
@ -114,14 +114,12 @@ Optional<SmallVector<Type, 4>> InferShapeForFunctionReturnType(FuncOp func) {
|
||||
|
||||
// Returns if the shape inference pass supports an op outside the TF dialect.
|
||||
bool IsSupportedNonTFOp(Operation* op) {
|
||||
return isa<ReturnOp>(op) || isa<tf_device::ReturnOp>(op) ||
|
||||
isa<tf_executor::EnterOp>(op) || isa<tf_executor::ExitOp>(op) ||
|
||||
isa<tf_executor::FetchOp>(op) || isa<tf_executor::GraphOp>(op) ||
|
||||
isa<tf_executor::IslandOp>(op) || isa<tf_executor::LoopCondOp>(op) ||
|
||||
isa<tf_executor::MergeOp>(op) ||
|
||||
isa<tf_executor::NextIterationSinkOp>(op) ||
|
||||
isa<tf_executor::SwitchNOp>(op) || isa<tf_executor::SwitchOp>(op) ||
|
||||
isa<tf_executor::YieldOp>(op);
|
||||
return isa<ReturnOp, tf_device::ReturnOp, tf_executor::EnterOp,
|
||||
tf_executor::ExitOp, tf_executor::FetchOp, tf_executor::GraphOp,
|
||||
tf_executor::IslandOp, tf_executor::LoopCondOp,
|
||||
tf_executor::MergeOp, tf_executor::NextIterationSinkOp,
|
||||
tf_executor::SwitchNOp, tf_executor::SwitchOp,
|
||||
tf_executor::YieldOp>(op);
|
||||
}
|
||||
|
||||
// Returns whether a cast back would need to be inserted, e.g., whether the
|
||||
|
@ -440,7 +440,7 @@ llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
|
||||
};
|
||||
for (FuncOp func : funcs) {
|
||||
for (auto& op : func.front().getOperations()) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
continue;
|
||||
}
|
||||
|
@ -640,7 +640,7 @@ LogicalResult DecomposeTensorListOpsInternal(
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
// TODO(yuanzx): Add a pass to remove identities in device computation.
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
} else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
|
||||
|
@ -52,7 +52,7 @@ Operation* GetOpOfValue(Value value) {
|
||||
|
||||
// TODO(b/158596585): Replace this with a cost model analysis.
|
||||
bool IsTrivialUnaryOperation(Operation* op) {
|
||||
return llvm::isa<TF::CastOp>(op) || llvm::isa<TF::IdentityOp>(op);
|
||||
return llvm::isa<TF::CastOp, TF::IdentityOp>(op);
|
||||
}
|
||||
|
||||
// Adds outside compilation attributes to unary ops such as Identity/Cast ops
|
||||
|
@ -67,7 +67,7 @@ void GetAdjacentXlaShardingOp(Operation* op,
|
||||
return;
|
||||
}
|
||||
|
||||
if (llvm::isa<TF::IdentityOp>(op) || llvm::isa<TF::CastOp>(op)) {
|
||||
if (llvm::isa<TF::IdentityOp, TF::CastOp>(op)) {
|
||||
for (auto user : op->getUsers())
|
||||
GetAdjacentXlaShardingOp(user, sharding_op);
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ Value SkipIdentity(Value v, bool allow_other_use,
|
||||
while (auto result = v.dyn_cast<OpResult>()) {
|
||||
if (!(allow_other_use || v.hasOneUse())) break;
|
||||
auto op = result.getDefiningOp();
|
||||
if (!llvm::isa<TF::IdentityOp>(op) && !llvm::isa<TF::IdentityNOp>(op)) {
|
||||
if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
|
||||
break;
|
||||
}
|
||||
v = op->getOperand(result.getResultNumber());
|
||||
|
@ -306,9 +306,8 @@ void BreakUpIslands::BreakUpIsland(
|
||||
llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
|
||||
(*new_control_inputs)[other_island_op].push_back(sink_island_control);
|
||||
} else if (owner->getDialect() == island_op.getDialect() &&
|
||||
!llvm::isa<tf_executor::GraphOp>(owner) &&
|
||||
!llvm::isa<tf_executor::YieldOp>(owner) &&
|
||||
!llvm::isa<tf_executor::NextIterationSourceOp>(owner)) {
|
||||
!llvm::isa<tf_executor::GraphOp, tf_executor::YieldOp,
|
||||
tf_executor::NextIterationSourceOp>(owner)) {
|
||||
(*new_control_inputs)[owner].push_back(sink_island_control);
|
||||
} else {
|
||||
owner->emitOpError("adding control dependency not supported");
|
||||
|
@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower(
|
||||
return success();
|
||||
}
|
||||
|
||||
if (isa<xla_hlo::ReturnOp>(inst) || isa<mlir::ReturnOp>(inst)) {
|
||||
if (isa<xla_hlo::ReturnOp, mlir::ReturnOp>(inst)) {
|
||||
// Construct the return value for the function. If there are multiple
|
||||
// values returned, then create a tuple, else return value directly.
|
||||
xla::XlaOp return_value;
|
||||
|
@ -193,8 +193,7 @@ mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
|
||||
|
||||
const bool any_op_is_loop_variant = [&] {
|
||||
for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
|
||||
if (mlir::isa<mlir::AffineForOp>(op) ||
|
||||
mlir::isa<mlir::AffineStoreOp>(op)) {
|
||||
if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -174,8 +174,7 @@ struct DeadTempBufferRemoval
|
||||
for (auto result : op->getResults()) {
|
||||
if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
|
||||
// Store and Dealloc is OK.
|
||||
if (llvm::isa<mlir::StoreOp>(op) ||
|
||||
llvm::isa<mlir::DeallocOp>(op)) {
|
||||
if (llvm::isa<mlir::StoreOp, mlir::DeallocOp>(op)) {
|
||||
return true;
|
||||
}
|
||||
// Load without uses is also ok.
|
||||
@ -225,8 +224,8 @@ struct MoveScalarComputationsIntoGpuLaunch
|
||||
: mlir::PassWrapper<MoveScalarComputationsIntoGpuLaunch,
|
||||
mlir::FunctionPass> {
|
||||
static bool isInliningBeneficiary(mlir::Operation* op) {
|
||||
return llvm::isa<mlir::ConstantOp>(op) || llvm::isa<mlir::DimOp>(op) ||
|
||||
llvm::isa<mlir::SelectOp>(op) || llvm::isa<mlir::CmpIOp>(op);
|
||||
return llvm::isa<mlir::ConstantOp, mlir::DimOp, mlir::SelectOp,
|
||||
mlir::CmpIOp>(op);
|
||||
}
|
||||
|
||||
static bool extractBeneficiaryOps(
|
||||
|
Loading…
x
Reference in New Issue
Block a user