From 9a99c02411ed46af1304cf1c393ee1e6df1907d2 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Thu, 25 Jun 2020 08:45:22 -0700 Subject: [PATCH] [MLIR][NFC] Adopt variadic isa<> PiperOrigin-RevId: 318279074 Change-Id: I9845b0278737a4d91b0e1e6699ae008d78e76556 --- tensorflow/compiler/mlir/lite/flatbuffer_export.cc | 5 ++--- .../mlir/lite/quantization/quantization_driver.cc | 4 ++-- .../mlir/lite/quantization/quantization_utils.h | 6 +++--- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.cc | 10 ++++------ .../mlir/lite/transforms/default_quant_params.cc | 3 +-- .../tensorflow/analysis/side_effect_analysis.cc | 3 +-- .../transforms/annotate_parameter_replication.cc | 2 +- .../tensorflow/transforms/collection_ops_util.cc | 3 +-- .../mlir/tensorflow/transforms/constant_fold.cc | 3 +-- .../tensorflow/transforms/fused_kernel_matcher.cc | 2 +- .../transforms/optimize_global_tensors.cc | 2 +- .../transforms/promote_resources_to_args.cc | 3 +-- .../tensorflow/transforms/replicate_to_island.cc | 4 ++-- .../tensorflow/transforms/resource_op_lifting.cc | 2 +- .../mlir/tensorflow/transforms/shape_inference.cc | 14 ++++++-------- .../transforms/tensor_array_ops_decomposition.cc | 2 +- .../transforms/tensor_list_ops_decomposition.cc | 2 +- .../transforms/tpu_host_computation_expansion.cc | 2 +- .../transforms/tpu_sharding_identification_pass.cc | 2 +- .../tpu_variable_runtime_reformatting.cc | 2 +- .../mlir/tensorflow/translate/breakup-islands.cc | 5 ++--- tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc | 2 +- .../experimental/conv_emitter/conv_emitter.cc | 3 +-- .../xla/service/mlir_gpu/kernel_lowering.cc | 7 +++---- 24 files changed, 40 insertions(+), 53 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index e34e7ae7ca6..ee8b34598e2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -190,9 +190,8 @@ static StatusOr GetTFLiteType(Type type, } static bool IsConst(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op); + return isa(op); } template diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index f3e746c7a43..bc97c42c955 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -289,8 +289,8 @@ class QuantizationDriver { llvm::errs() << "\n\n\n" << current_op->getName() << "\n"; } fn_.walk([&](Operation *op) { - if (llvm::isa(op) || - llvm::isa(op) || llvm::isa(op)) + if (llvm::isa( + op)) return; if (current_op == op) llvm::errs() << "===>>>"; llvm::errs() << op->getName() << " : ("; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index f17e44cd756..ad99b1c58d2 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -172,7 +172,7 @@ struct QuantizationPattern : public RewritePattern { Value quantized_value = op->getResult(0); for (Operation* quantized_op : quantized_value.getUsers()) { // If it is requantize op, we shouldn't rewrite this op. - if (llvm::isa(quantized_op) || llvm::isa(quantized_op)) { + if (llvm::isa(quantized_op)) { return failure(); } @@ -180,8 +180,8 @@ struct QuantizationPattern : public RewritePattern { // ops dialect, we shouldn't rewrite. if (quantized_op->isKnownTerminator() || quantized_op->hasTrait() || - llvm::isa(quantized_op) || - llvm::isa(quantized_op)) { + llvm::isa( + quantized_op)) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 2e45953c5fa..aa89472f92a 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -49,12 +49,10 @@ using mlir::OwningModuleRef; using stream_executor::port::StatusOr; bool IsControlFlowV1Op(Operation* op) { - return mlir::isa(op) || - mlir::isa(op) || - mlir::isa(op) || - mlir::isa(op) || - mlir::isa(op) || - mlir::isa(op); + return mlir::isa(op); } mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) { diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index c23ae9fcfab..451eb613543 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -110,8 +110,7 @@ void DefaultQuantParamsPass::runOnFunction() { func.walk([&](Operation *op) { if (op->isKnownTerminator() || op->hasTrait() || - llvm::isa(op) || - llvm::isa(op)) + llvm::isa(op)) return; for (auto res : op->getResults()) { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index f7b88317cd4..35f02ba8445 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -100,8 +100,7 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index, value = graph.GetFetch().getOperand(res_num); } else if (auto island = llvm::dyn_cast(op)) { value = island.GetYield().getOperand(res_num); - } else if (llvm::isa(op) || - llvm::isa(op)) { + } else if (llvm::isa(op)) { value = op->getOperand(res_num); } else { return -1; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 6ba6f416c70..dc24e478378 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -48,7 +48,7 @@ struct AnnotateParameterReplication // tf.IdentityOp or a tf.ReadVariableOp. Value SkipIdentityAndReadVariable(Value v) { while (auto op = v.getDefiningOp()) { - if (!(isa(op) || isa(op))) break; + if (!isa(op)) break; v = op->getOperand(0); } return v; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 8951b49acb7..58c4eac5c95 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -219,8 +219,7 @@ llvm::Optional GetElementTypeFromAccess( auto type_from_callee = GetElementTypeFromAccess( callee.getArgument(use.getOperandNumber()), module, infer_from_op); if (type_from_callee.hasValue()) return type_from_callee; - } else if (llvm::isa(use.getOwner()) || - llvm::isa(use.getOwner())) { + } else if (llvm::isa(use.getOwner())) { auto type_from_alias = GetElementTypeFromAccess( use.getOwner()->getResult(use.getOperandNumber()), module, infer_from_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 55a0b5c3fd3..16de2874fda 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -49,8 +49,7 @@ LogicalResult ConstantFoldFallbackHook( } // Do not execute function calls. - if (llvm::isa(inst) || llvm::isa(inst) || - llvm::isa(inst)) { + if (llvm::isa(inst)) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index d10f5e26e8f..21f4581f76a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -53,7 +53,7 @@ struct FusedKernelMatcherPass }; bool IsActivationFunction(Operation *op) { - return isa(op) || isa(op) || isa(op); + return isa(op); } // Finds and returns an activation op that uses the result of `op`. If there are diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 07cc6203cbd..67a6c8dd6dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -96,7 +96,7 @@ class ResourceAnalyzer { } func.walk([&](Operation* op) { - if (isa(op) || isa(op)) { + if (isa(op)) { return; } if (auto assign_variable = dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index af36770f496..961287b0b1f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -97,8 +97,7 @@ llvm::SmallSet GetCompositeResourceUserNames( // the error message are ordered. llvm::SmallSet composite_users; for (Operation* user : resource.getUsers()) - if (!llvm::isa(user) && - !llvm::isa(user)) + if (!llvm::isa(user)) composite_users.insert(user->getName().getStringRef()); return composite_users; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 15eb5593651..6eedfbbaf4b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -53,8 +53,8 @@ struct ReplicateToIslandPass // Returns whether op requires `_xla_replica_id` attribute. bool RequiresReplicaIDAttribute(Operation* op) { - return llvm::isa(op) || - llvm::isa(op); + return llvm::isa(op); } // Adds integer attribute that represents replica id for replicated ops that diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 799ab3a0f0d..2d30bbd1b93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -140,7 +140,7 @@ struct ResourceOpLiftingPass // such nodes to carry information. void RemoveIdentity(Block* block) { for (auto& op : llvm::make_early_inc_range(*block)) { - if (isa(&op) || isa(&op)) { + if (isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 8c537d01d5c..5907e72e602 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -114,14 +114,12 @@ Optional> InferShapeForFunctionReturnType(FuncOp func) { // Returns if the shape inference pass supports an op outside the TF dialect. bool IsSupportedNonTFOp(Operation* op) { - return isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || - isa(op) || - isa(op) || isa(op) || - isa(op); + return isa(op); } // Returns whether a cast back would need to be inserted, e.g., whether the diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index a9e1243714e..cbd24f8a815 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -440,7 +440,7 @@ llvm::SmallDenseMap> AccessedGradients( }; for (FuncOp func : funcs) { for (auto& op : func.front().getOperations()) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); continue; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index b118ab6c6c9..11153f0dfc3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -640,7 +640,7 @@ LogicalResult DecomposeTensorListOpsInternal( decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { // TODO(yuanzx): Add a pass to remove identities in device computation. - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto list = llvm::dyn_cast(&op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc index e2c439feeef..2a3f2197155 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc @@ -52,7 +52,7 @@ Operation* GetOpOfValue(Value value) { // TODO(b/158596585): Replace this with a cost model analysis. bool IsTrivialUnaryOperation(Operation* op) { - return llvm::isa(op) || llvm::isa(op); + return llvm::isa(op); } // Adds outside compilation attributes to unary ops such as Identity/Cast ops diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index b05e87c6485..1203eea2f84 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -67,7 +67,7 @@ void GetAdjacentXlaShardingOp(Operation* op, return; } - if (llvm::isa(op) || llvm::isa(op)) { + if (llvm::isa(op)) { for (auto user : op->getUsers()) GetAdjacentXlaShardingOp(user, sharding_op); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index d88982d9ee7..b8f55e3b979 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -127,7 +127,7 @@ Value SkipIdentity(Value v, bool allow_other_use, while (auto result = v.dyn_cast()) { if (!(allow_other_use || v.hasOneUse())) break; auto op = result.getDefiningOp(); - if (!llvm::isa(op) && !llvm::isa(op)) { + if (!llvm::isa(op)) { break; } v = op->getOperand(result.getResultNumber()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 7284626c46a..f09cf7b093e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -306,9 +306,8 @@ void BreakUpIslands::BreakUpIsland( llvm::dyn_cast(owner->getParentOp())) { (*new_control_inputs)[other_island_op].push_back(sink_island_control); } else if (owner->getDialect() == island_op.getDialect() && - !llvm::isa(owner) && - !llvm::isa(owner) && - !llvm::isa(owner)) { + !llvm::isa(owner)) { (*new_control_inputs)[owner].push_back(sink_island_control); } else { owner->emitOpError("adding control dependency not supported"); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 7a576780c61..ff751a1f9f5 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower( return success(); } - if (isa(inst) || isa(inst)) { + if (isa(inst)) { // Construct the return value for the function. If there are multiple // values returned, then create a tuple, else return value directly. xla::XlaOp return_value; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc index 36cf37e4044..1bac9a13553 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter.cc @@ -193,8 +193,7 @@ mlir::Operation* HoistAndFix(llvm::iplist::iterator begin_op, const bool any_op_is_loop_variant = [&] { for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) { - if (mlir::isa(op) || - mlir::isa(op)) { + if (mlir::isa(op)) { return true; } } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 196ea218ef3..3d21379a624 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -174,8 +174,7 @@ struct DeadTempBufferRemoval for (auto result : op->getResults()) { if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { // Store and Dealloc is OK. - if (llvm::isa(op) || - llvm::isa(op)) { + if (llvm::isa(op)) { return true; } // Load without uses is also ok. @@ -225,8 +224,8 @@ struct MoveScalarComputationsIntoGpuLaunch : mlir::PassWrapper { static bool isInliningBeneficiary(mlir::Operation* op) { - return llvm::isa(op) || llvm::isa(op) || - llvm::isa(op) || llvm::isa(op); + return llvm::isa(op); } static bool extractBeneficiaryOps(