From ae2b8a6dab670a3cf67c4b3ab770722ca84de0cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tar=C3=A9=20Gaskin?=
Date: Sun, 26 Jul 2020 22:00:14 +0000
Subject: [PATCH] mlir directory resolutions

---
 .../mlir/tensorflow/ir/tf_executor.cc         |  8 ++++---
 .../mlir/tensorflow/transforms/einsum.cc      |  6 ++---
 .../transforms/promote_resources_to_args.cc   |  5 ++--
 .../transforms/resource_op_lifting.cc         |  7 +++---
 .../transforms/tpu_cluster_formation.cc       |  6 +++--
 .../tpu_merge_variables_with_execute.cc       |  6 ++---
 .../tpu_variable_runtime_reformatting.cc      |  4 ++--
 .../tensorflow/translate/export_graphdef.cc   |  6 +++--
 .../mlir/tensorflow/translate/import_model.cc |  6 +++--
 .../translate/mlir_roundtrip_flags.cc         |  2 +-
 .../mlir/tensorflow/utils/export_utils.cc     |  4 ++--
 .../utils/tpu_rewrite_device_util.cc          |  6 +++--
 .../tensorflow/utils/xla_sharding_util.cc     |  5 ++--
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      | 10 ++++----
 .../mlir/xla/transforms/legalize_tf.cc        | 23 +++++++++++--------
 15 files changed, 60 insertions(+), 44 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
index 8db06e83527..c18723b0982 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc
@@ -190,14 +190,15 @@ LogicalResult Verify(GraphOp graph) {
   for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
     Value operand = fetch.getOperand(i);
     // Break out of the loop at the first control operand encountered.
+    const int64_t num_results = graph.getNumResults();
     if (operand.getType().isa<ControlType>()) {
-      if (i != graph.getNumResults())
+      if (i != num_results)
         return fetch.emitOpError()
                << "operand #" << i
                << " is a control type, can't be bound to a graph result";
       break;
     }
-    if (i >= graph.getNumResults())
+    if (i >= num_results)
       return fetch.emitOpError() << "operand #" << i
                                  << " does not have a graph results to bind";
     if (graph.getResult(i).getType() != operand.getType())
@@ -311,7 +312,8 @@ LogicalResult Verify(IslandOp island) {
   // Ensure that the yield terminator operands matches the island results type.
   int result_count = island.getNumResults() - 1;  // -1 for the control token
-  if (yield.getNumOperands() != result_count)
+  const int num_operands = yield.getNumOperands();
+  if (num_operands != result_count)
     return yield.emitOpError()
            << "has " << yield.getNumOperands()
            << " operand, but island returns " << result_count;
 
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
index c05a0ad1b62..69dab58c3f5 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc
@@ -74,7 +74,7 @@ constexpr int kNumSupportedEquationVariables = 5;  // A - E for now.
 bool tokenizeEquation(const llvm::StringRef& equation,
                       std::vector* tokens) {
   std::map label_axis_mapping;
-  int index = 0;
+  size_t index = 0;
   int variable_count = 0;
   llvm::Regex r("[[:alpha:]]");
   while (index < equation.size()) {
@@ -177,7 +177,7 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
   auto perm_attr = DenseElementsAttr::get(perm_type, permutation);
   auto perm_op = rewriter->create<TF::ConstOp>(loc, perm_type, perm_attr);
   std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
-  for (int i = 0; i < shape.size(); ++i) {
+  for (int i = 0, end = shape.size(); i < end; ++i) {
     transposed_shape[i] = shape[permutation[i]];
   }
   auto transposed_type =
@@ -197,7 +197,7 @@ TF::SumOp createSumOp(Value value, Location loc,
   auto redux_op = rewriter->create<TF::ConstOp>(loc, redux_type, redux_attr);
   std::vector<int64_t> sum_shape(shape.size() - redux_axes.size());
   int count = 0;
-  for (int i = 0; i < shape.size(); ++i) {
+  for (int i = 0, end = shape.size(); i < end; ++i) {
     if (std::find(redux_axes.begin(), redux_axes.end(), i) ==
         redux_axes.end()) {
       sum_shape[count] = shape[i];
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
index 961287b0b1f..4926dbaf4fb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc
@@ -304,7 +304,7 @@ LogicalResult PromoteResourcesToArguments(
       continue;
     }
 
-    const auto index = resource_and_index.index();
+    const long int index = resource_and_index.index();
     const bool is_var_handle = index >= var_handles_start_idx;
     if (resource.write) {
       if (!is_var_handle || resource.read) {
@@ -342,7 +342,8 @@ LogicalResult PromoteResourcesToArguments(
   }
 
   // Rewrite return if there are variable writes.
-  if (return_operands.size() > num_results_before) {
+  const int return_operands_size = return_operands.size();
+  if (return_operands_size > num_results_before) {
     builder.create<ReturnOp>(return_op.getLoc(), return_operands);
     return_op.erase();
   }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index 74679f19941..9c4963ea1c8 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -656,7 +656,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
                                      arg_data_type_and_updated_output_index);
   new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
   // Replace uses.
-  for (int64_t i = 0; i < old_to_new_indices.size(); ++i) {
+  for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
     if (old_to_new_indices[i] >= 0) {
       while_op.getResult(i).replaceAllUsesWith(
           new_while.getResult(old_to_new_indices[i]));
@@ -802,7 +802,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
   AddLoadsStoresOutsideControlFlowOp(new_op,
                                      arg_data_type_and_updated_output_index);
   // Replace uses.
-  for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
+  for (int64_t i = 0, end = old_to_new_output_indices.size(); i < end; ++i) {
     if (old_to_new_output_indices[i] >= 0) {
       op.getResult(i).replaceAllUsesWith(
           new_op.getResult(old_to_new_output_indices[i]));
@@ -946,7 +946,8 @@ void UpdatePartitionedCallOpWithNewCallee(
   AddLoadsStoresOutsideControlFlowOp(
       new_call, lifting_info.arg_data_type_and_updated_output_index);
   // Replace uses.
-  for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
+  for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
+       i < end; ++i) {
     if (lifting_info.old_to_new_output_indices[i] >= 0) {
       call_op.getResult(i).replaceAllUsesWith(
           new_call.getResult(lifting_info.old_to_new_output_indices[i]));
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index 9abf67b62a9..162ecd77d4f 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -344,8 +344,9 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
   for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
     auto input = pos_and_input.value();
     bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
+    const int num_operands = input->getNumOperands();
     int num_inputs = is_packed ? 1 : num_replicas;
-    if (input->getNumOperands() != num_inputs)
+    if (num_operands != num_inputs)
       return input->emitOpError() << "requires " << num_inputs << " operands";
 
     auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
@@ -393,7 +394,8 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
                << "requires output of " << cluster.getOperationName()
                << " to lead to a 'tf.TPUReplicatedOutput' op";
 
-    if (def->getNumResults() != num_replicas)
+    const int def_NumResults = def->getNumResults();
+    if (def_NumResults != num_replicas)
       return def->emitOpError() << "requires " << num_replicas << " results";
 
     auto replicate_outputs = llvm::make_range(
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
index 3fd0dcd5a67..52c9287b619 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc
@@ -298,7 +298,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
   // Populate infos.old_to_new_output_mapping.
   int new_output_index = 0;
   infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
-  for (int i = 0; i < execute_launch.getNumResults(); ++i) {
+  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
     if (output_fused[i]) {
       infos.old_to_new_output_mapping[i] = -1;
     } else {
@@ -375,7 +375,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
   // Replace the uses of the original parallel_execute for the region containing
   // the merged execute.
   auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
-  for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
+  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
     if (infos.old_to_new_output_mapping[i] < 0) continue;
     old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
         infos.old_to_new_output_mapping[i] + num_results_before_region));
@@ -407,7 +407,7 @@ void ReplaceExecute(tf_device::LaunchOp execute_launch,
                     tf_device::LaunchOp merged_execute_launch,
                     const VariableAccessesForTPUExecute& infos) {
   // Replace the uses.
-  for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
+  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
     if (infos.old_to_new_output_mapping[i] < 0) continue;
     execute_launch.getResult(i).replaceAllUsesWith(
         merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
index 12ce8c57f73..b33d37116cb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc
@@ -351,7 +351,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
   cond.setType(FunctionType::get(append_types(cond.getType().getInputs()),
                                  cond.getType().getResults(),
                                  cond.getContext()));
-  for (int64_t i = 0; i < state_vars.size(); ++i) {
+  for (int64_t i = 0, end = state_vars.size(); i < end; ++i) {
     int64_t arg_index = body.getNumArguments() - state_vars.size() + i;
     TF::VarHandleOp state_var = state_vars[i];
     auto device_attr = state_var.getAttr(kDeviceAttr);
@@ -368,7 +368,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
   if (new_while_op.output_shapes().size() != 0) {
     auto new_output_shapes = llvm::to_vector<4>(new_while_op.output_shapes());
     // VarHandleOp is a scalar shape resource.
-    for (int64_t i = 0; i < state_vars.size(); ++i) {
+    for (int64_t i = 0, end = state_vars.size(); i < end; ++i) {
       new_output_shapes.push_back(
           mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef<int64_t>()));
     }
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
index 7983dfe0065..e508f8fbd6b 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc
@@ -511,14 +511,16 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   // generate unique names.
   if (!output_names.empty()) {
     const int num_data_results = graph_op.getNumResults();
-    TF_RET_CHECK(output_names.size() == num_data_results)
+    const int64 output_names_size = output_names.size();
+    TF_RET_CHECK(output_names_size == num_data_results)
         << "output names (" << output_names.size()
        << ") != terminator operands (" << num_data_results << ")";
     llvm::DenseMap output_op_to_name;
     llvm::StringMap name_to_op;
     for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
       // Skip control rets.
-      if (it.index() >= num_data_results) break;
+      const int64 it_index = it.index();
+      if (it_index >= num_data_results) break;
       // TODO(jpienaar): If there is a result index specified, ensure only one
       // and that it matches the result index of the op.
       std::string orig_name(output_names[it.index()]);
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index a12378b66ba..2bf2c900cd2 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -2387,7 +2387,8 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
                                        "' is missing attribute 'index'");
 
     auto index = attr->i();
-    if (nodes->size() < index + 1) nodes->resize(index + 1);
+    const int nodes_size = nodes->size();
+    if (nodes_size < index + 1) nodes->resize(index + 1);
 
     if ((*nodes)[index].node != nullptr)
       return errors::InvalidArgument(node->type_string(), " node '",
@@ -3108,7 +3109,8 @@ Status CreateSavedModelIR(
     TF_ASSIGN_OR_RETURN(auto input_index_paths,
                         input_linearizer.GetLeafIndexPaths(
                             error_context + "in input signature: "));
-    if (bound_input_base != input_index_paths.size()) {
+    const int input_index_paths_size = input_index_paths.size();
+    if (bound_input_base != input_index_paths_size) {
       return errors::InvalidArgument(
           error_context,
           "Argument mismatch between concrete function input signature "
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
index 4640cb6ce64..f6d370ca604 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc
@@ -94,7 +94,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
   }
 
   // StringMap doesn't support reserve else reserve input map size here.
-  for (int i = 0; i < node_names.size(); i++) {
+  for (int i = 0, end = node_names.size(); i < end; i++) {
     auto& name = node_names[i];
     if (name.empty()) continue;
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
index 7e018966396..0364b935b92 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc
@@ -82,7 +82,7 @@ Status ConvertLocation(mlir::Location inst_loc,
     if (locations.size() <= 1)
      return errors::InvalidArgument("expected experimental debuf info.");
     // skip the first one, which is the name of the node_def.
-    for (int i = 0; i < locations.size() - 1; ++i) {
+    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
     }
   }
@@ -518,7 +518,7 @@ Status SetSizeAttribute(absl::string_view name, size_t size,
     // This should be extremely rare as it means we are adding the same
     // attribute multiple times/have some redundancy in representing this
     // attribute.
-    int64 actual_size = result.first->second.i();
+    size_t actual_size = result.first->second.i();
     // Just check via string output as we shouldn't get here and if we do they
     // should be trivially the same, else fail.
     if (actual_size != size)
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
index f884b75bce1..843d491c330 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc
@@ -149,7 +149,8 @@ Status GetTPUDevices(
            std::next(system_devices.begin()), system_devices.end())) {
     auto host_tpu_devices = lookup(device_spec);
     // Check number of TPU devices per host all match.
-    if (num_tpus_per_host != host_tpu_devices.size())
+    const int64 host_tpu_devices_size = host_tpu_devices.size();
+    if (num_tpus_per_host != host_tpu_devices_size)
       return errors::InvalidArgument(
           "expected the number of TPU devices per host to be ",
          num_tpus_per_host, ", got ", host_tpu_devices.size());
@@ -354,7 +355,8 @@ GetGeneralTPUExecutionDeviceAssignment(
 
   const int expected_device_assignment_size =
       num_replicas * num_cores_per_replica * kTPUTopologyRank;
-  if (device_assignment_attr.size() != expected_device_assignment_size)
+  const int device_assignment_attr_size = device_assignment_attr.size();
+  if (device_assignment_attr_size != expected_device_assignment_size)
     return errors::InvalidArgument(
         "length of '", kDeviceAssignmentAttr,
        "' must be 'num_replicas' * 'num_cores_per_replica' * ",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
index 083a5abf840..f662005f8a3 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
@@ -242,7 +242,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices(
         cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
     if (mlir::failed(result)) return mlir::failure();
 
-    if (tiled_inputs.size() != num_cores_per_replica)
+    const int64 tiled_inputs_size = tiled_inputs.size();
+    if (tiled_inputs_size != num_cores_per_replica)
       cluster_func.emitError(llvm::formatv(
           "incorrect {0}-th tiled input sharding received. "
           "Product of tile sharding splits({1}) must be equal to "
@@ -376,7 +377,7 @@ mlir::LogicalResult HandleTileShardedOutputs(
   llvm::SmallVector new_outputs;
   new_outputs.reserve(num_splits);
 
-  for (int i = 0; i < outputs_to_merge.size(); i = i + num_splits) {
+  for (int i = 0, end = outputs_to_merge.size(); i < end; i = i + num_splits) {
     mlir::TF::ConcatOp concat_op;
     auto result =
         CreateConcatOp(concat_dimension, location,
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index a4c3c43cfbf..7faac83a8de 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -170,8 +170,8 @@ static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
 static std::vector<xla::ReplicaGroup> Convert_replica_groups(
     mlir::DenseIntElementsAttr groups) {
-  int64_t num_groups = groups.getType().getDimSize(0);
-  int64_t group_size = groups.getType().getDimSize(1);
+  uint64_t num_groups = groups.getType().getDimSize(0);
+  uint64_t group_size = groups.getType().getDimSize(1);
 
   std::vector<xla::ReplicaGroup> result;
   result.reserve(num_groups);
 
@@ -435,14 +435,14 @@ static void ExtractShardingsFromFunction(
     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
   arg_shardings->resize(function.getNumArguments(),
                         absl::optional<xla::OpSharding>());
-  for (int i = 0; i < function.getNumArguments(); ++i)
+  for (int i = 0, end = function.getNumArguments(); i < end; ++i)
     if (auto sharding =
             function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
       (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
 
   ret_shardings->resize(function.getNumResults(),
                         absl::optional<xla::OpSharding>());
-  for (int i = 0; i < function.getNumResults(); ++i)
+  for (int i = 0, end = function.getNumResults(); i < end; ++i)
     if (auto sharding =
             function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
       (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
@@ -758,7 +758,7 @@ LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
   auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
   auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
   auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
-  for (xla::int64 i = 0; i < edge_padding_low.size(); ++i) {
+  for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
     auto* dims = padding_config.add_dimensions();
     dims->set_edge_padding_low(edge_padding_low[i]);
     dims->set_edge_padding_high(edge_padding_high[i]);
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 4549386ce16..2a5f553240b 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -365,7 +365,7 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                     ArrayRef<int64_t> minor_starts,
                                     OpBuilder *builder) {
   llvm::SmallVector dus_starts(minor_starts.size());
-  for (int64_t i = 0; i < minor_starts.size(); ++i) {
+  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
     dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                          minor_starts[i], builder);
   }
@@ -808,7 +808,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
   values.reserve(shaped_type.getNumElements() / shape[1]);
 
   for (auto it : llvm::enumerate(int_attr.getIntValues())) {
-    if (it.index() % shape[1] == column) {
+    if (static_cast(it.index() % shape[1]) == column) {
       values.push_back(it.value().getSExtValue());
     }
   }
@@ -1836,6 +1836,9 @@ Operation *AvgPoolDivideByCount(
   return result;
 }
 
+Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); }
+Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); }
+
 // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window
 // dimensions with add as the reduction function. The reduction result is
 // then divided by the number of elements in the window.
@@ -1846,8 +1849,9 @@ class ConvertAvgPoolOp : public OpRewritePattern {
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
+    Value input_value = GetAvgPoolInput(op);
     auto input_type =
-        op.value().getType().template dyn_cast<RankedTensorType>();
+        input_value.getType().template dyn_cast<RankedTensorType>();
     if (!input_type) return failure();
 
     // We will do accumulation first; use a larger bitwidth if suitable.
@@ -1862,8 +1866,6 @@ class ConvertAvgPoolOp : public OpRewritePattern {
     else
       result_type = UnrankedTensorType::get(sum_element_type);
 
-    Value input_value = op.value();
-
     // Convert if we need enlarge the element type's bitwidth.
     if (input_element_type != sum_element_type)
       input_value = rewriter.create(op.getLoc(), input_value,
@@ -2680,7 +2682,7 @@ class ConvertSplitVOp : public OpRewritePattern {
     SmallVector slices;
     slices.reserve(op.getNumResults());
 
-    for (int i = 0; i < op.getNumResults(); ++i) {
+    for (int i = 0, end = op.getNumResults(); i < end; ++i) {
       end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
       slices.push_back(rewriter.create(
           op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
@@ -2855,7 +2857,7 @@ class ConvertStridedSliceOp : public OpRewritePattern {
     // verifier.
     int64_t slicing_dim_size =
         op.begin().getType().cast<RankedTensorType>().getShape()[0];
-    auto input_rank = input_shape.size();
+    const int input_rank = input_shape.size();
     for (int d = slicing_dim_size; d < input_rank; ++d) {
       // We only support slicing major dimensions, so minor dimensions after
       // slicing dimensions are all sliced with their full sizes.
@@ -2896,7 +2898,7 @@ class ConvertStridedSliceOp : public OpRewritePattern {
     }
 
     // For non-slice dims, get the full slice of that dimension.
-    for (int d = slicing_dim_size; d < input_shape.size(); ++d) {
+    for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
       slice_sizes.push_back(input_shape[d]);
       slice_begin_indices.push_back(zero);
     }
@@ -3592,7 +3594,8 @@ class ConvertTileOp : public OpRewritePattern {
         multiples.getType().getRank() != 1)
       return failure();
 
-    if (multiples.getNumElements() != input_shape.size()) return failure();
+    const int64_t input_shape_size = input_shape.size();
+    if (multiples.getNumElements() != input_shape_size) return failure();
 
     SmallVector broadcasted_shape;
     SmallVector broadcast_dimensions;
@@ -4379,7 +4382,7 @@ class ConvertUnpackOp : public OpRewritePattern {
     SmallVector results;
    results.reserve(op.getNumResults());
 
-    for (int i = 0; i < op.getNumResults(); ++i) {
+    for (int i = 0, end = op.getNumResults(); i < end; ++i) {
      begin_indices[axis] = i;
      end_indices[axis] = i + 1;