Merge pull request #41750 from tg-at-google:wsign-compare-semi-final-mlir
PiperOrigin-RevId: 323913819
Change-Id: Ib6b2ea6b2cc4bf6acc17ae5ba7b07723fa79cb6b
commit 76ca1e9eb3
Changed files under tensorflow/compiler/mlir:
  tensorflow/ir
  tensorflow/transforms:
    einsum.cc
    promote_resources_to_args.cc
    resource_op_lifting.cc
    tpu_cluster_formation.cc
    tpu_merge_variables_with_execute.cc
    tpu_variable_runtime_reformatting.cc
  tensorflow/translate
  tensorflow/utils
  xla
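Every hunk below silences a -Wsign-compare warning, where a signed index or bound is compared against an unsigned quantity such as `std::vector::size()`. A minimal standalone sketch of the warning and the two idioms the commit leans on most (not code from the diff; `shape` and `column` are made-up names):

#include <cstddef>
#include <cstdint>
#include <vector>

int64_t Demo(const std::vector<int64_t>& shape, int column) {
  int64_t sum = 0;

  // Before: signed `i` vs. unsigned `shape.size()` triggers -Wsign-compare:
  //   for (int i = 0; i < shape.size(); ++i) ...

  // Idiom 1 (most hunks): evaluate the bound once into a local with the same
  // signedness as the index; the comparison becomes signed-vs-signed and the
  // size() call is hoisted out of the loop condition.
  for (int i = 0, end = shape.size(); i < end; ++i) sum += shape[i];

  // Idiom 2 (the SliceDenseIntElementsAttrColumn2D hunk): cast one side at
  // the comparison site when neither declared type can change.
  const size_t remainder = shape.size() % 3;
  if (static_cast<int>(remainder) == column) sum = 0;

  return sum;
}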
@@ -190,14 +190,15 @@ LogicalResult Verify(GraphOp graph) {
   for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
     Value operand = fetch.getOperand(i);
     // Break out of the loop at the first control operand encountered.
+    const int64_t num_results = graph.getNumResults();
     if (operand.getType().isa<ControlType>()) {
-      if (i != graph.getNumResults())
+      if (i != num_results)
         return fetch.emitOpError()
                << "operand #" << i
                << " is a control type, can't be bound to a graph result";
       break;
     }
-    if (i >= graph.getNumResults())
+    if (i >= num_results)
       return fetch.emitOpError()
              << "operand #" << i << " does not have a graph results to bind";
     if (graph.getResult(i).getType() != operand.getType())
@@ -311,7 +312,8 @@ LogicalResult Verify(IslandOp island) {
 
   // Ensure that the yield terminator operands matches the island results type.
   int result_count = island.getNumResults() - 1;  // -1 for the control token
-  if (yield.getNumOperands() != result_count)
+  const int num_operands = yield.getNumOperands();
+  if (num_operands != result_count)
     return yield.emitOpError()
            << "has " << yield.getNumOperands()
            << " operand, but island returns " << result_count;
@@ -74,7 +74,7 @@ constexpr int kNumSupportedEquationVariables = 5;  // A - E for now.
 bool tokenizeEquation(const llvm::StringRef& equation,
                       std::vector<EquationToken>* tokens) {
   std::map<char, EquationToken> label_axis_mapping;
-  int index = 0;
+  size_t index = 0;
   int variable_count = 0;
   llvm::Regex r("[[:alpha:]]");
   while (index < equation.size()) {
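This hunk takes the opposite tack from most of the commit: `index` is only ever compared against the unsigned `equation.size()`, so changing its declared type to `size_t` fixes the comparison without restructuring the loop. A rough standalone sketch of the same choice (`CountLetters` is illustrative, not the actual tokenizer):

#include <cctype>
#include <cstddef>
#include <string>

int CountLetters(const std::string& equation) {
  int variable_count = 0;
  size_t index = 0;  // was `int index = 0;`, which tripped -Wsign-compare
  while (index < equation.size()) {  // size_t vs. size_t: no warning
    if (std::isalpha(static_cast<unsigned char>(equation[index])))
      ++variable_count;
    ++index;
  }
  return variable_count;
}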
@@ -177,7 +177,7 @@ TF::TransposeOp createTransposeOp(Value value, Location loc,
   auto perm_attr = DenseElementsAttr::get(perm_type, permutation);
   auto perm_op = rewriter->create<ConstantOp>(loc, perm_type, perm_attr);
   std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
-  for (int i = 0; i < shape.size(); ++i) {
+  for (int i = 0, end = shape.size(); i < end; ++i) {
     transposed_shape[i] = shape[permutation[i]];
   }
   auto transposed_type =
@@ -197,7 +197,7 @@ TF::SumOp createSumOp(Value value, Location loc,
   auto redux_op = rewriter->create<ConstantOp>(loc, redux_type, redux_attr);
   std::vector<int64_t> sum_shape(shape.size() - redux_axes.size());
   int count = 0;
-  for (int i = 0; i < shape.size(); ++i) {
+  for (int i = 0, end = shape.size(); i < end; ++i) {
     if (std::find(redux_axes.begin(), redux_axes.end(), i) ==
         redux_axes.end()) {
       sum_shape[count] = shape[i];
@@ -304,7 +304,7 @@ LogicalResult PromoteResourcesToArguments(
       continue;
     }
 
-    const auto index = resource_and_index.index();
+    const int64_t index = resource_and_index.index();
     const bool is_var_handle = index >= var_handles_start_idx;
     if (resource.write) {
       if (!is_var_handle || resource.read) {
@@ -342,7 +342,8 @@ LogicalResult PromoteResourcesToArguments(
   }
 
   // Rewrite return if there are variable writes.
-  if (return_operands.size() > num_results_before) {
+  const int return_operands_size = return_operands.size();
+  if (return_operands_size > num_results_before) {
     builder.create<ReturnOp>(return_op.getLoc(), return_operands);
     return_op.erase();
   }
@@ -648,7 +648,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
   AddLoadsStoresOutsideControlFlowOp(new_while,
                                      arg_data_type_and_updated_output_index);
   // Replace uses.
-  for (int64_t i = 0; i < old_to_new_indices.size(); ++i) {
+  for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) {
     if (old_to_new_indices[i] >= 0) {
       while_op.getResult(i).replaceAllUsesWith(
           new_while.getResult(old_to_new_indices[i]));
@@ -794,7 +794,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
   AddLoadsStoresOutsideControlFlowOp(new_op,
                                      arg_data_type_and_updated_output_index);
   // Replace uses.
-  for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
+  for (int64_t i = 0, end = old_to_new_output_indices.size(); i < end; ++i) {
     if (old_to_new_output_indices[i] >= 0) {
       op.getResult(i).replaceAllUsesWith(
           new_op.getResult(old_to_new_output_indices[i]));
@@ -938,7 +938,8 @@ void UpdatePartitionedCallOpWithNewCallee(
   AddLoadsStoresOutsideControlFlowOp(
       new_call, lifting_info.arg_data_type_and_updated_output_index);
   // Replace uses.
-  for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) {
+  for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size();
+       i < end; ++i) {
     if (lifting_info.old_to_new_output_indices[i] >= 0) {
       call_op.getResult(i).replaceAllUsesWith(
           new_call.getResult(lifting_info.old_to_new_output_indices[i]));
@@ -344,8 +344,9 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
   for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
     auto input = pos_and_input.value();
     bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
+    const int num_operands = input->getNumOperands();
     int num_inputs = is_packed ? 1 : num_replicas;
-    if (input->getNumOperands() != num_inputs)
+    if (num_operands != num_inputs)
       return input->emitOpError() << "requires " << num_inputs << " operands";
 
     auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
@@ -393,7 +394,8 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
              << "requires output of " << cluster.getOperationName()
              << " to lead to a 'tf.TPUReplicatedOutput' op";
 
-    if (def->getNumResults() != num_replicas)
+    const int def_NumResults = def->getNumResults();
+    if (def_NumResults != num_replicas)
       return def->emitOpError() << "requires " << num_replicas << " results";
 
     auto replicate_outputs = llvm::make_range(
@@ -298,7 +298,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
   // Populate infos.old_to_new_output_mapping.
   int new_output_index = 0;
   infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
-  for (int i = 0; i < execute_launch.getNumResults(); ++i) {
+  for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) {
     if (output_fused[i]) {
       infos.old_to_new_output_mapping[i] = -1;
     } else {
@@ -375,7 +375,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
   // Replace the uses of the original parallel_execute for the region containing
   // the merged execute.
   auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
-  for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
+  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
     if (infos.old_to_new_output_mapping[i] < 0) continue;
     old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
         infos.old_to_new_output_mapping[i] + num_results_before_region));
@@ -407,7 +407,7 @@ void ReplaceExecute(tf_device::LaunchOp execute_launch,
                     tf_device::LaunchOp merged_execute_launch,
                     const VariableAccessesForTPUExecute& infos) {
   // Replace the uses.
-  for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
+  for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) {
     if (infos.old_to_new_output_mapping[i] < 0) continue;
     execute_launch.getResult(i).replaceAllUsesWith(
         merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
@@ -351,7 +351,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body,
   cond.setType(FunctionType::get(append_types(cond.getType().getInputs()),
                                  cond.getType().getResults(),
                                  cond.getContext()));
-  for (int64_t i = 0; i < state_vars.size(); ++i) {
+  for (int64_t i = 0, end = state_vars.size(); i < end; ++i) {
     int64_t arg_index = body.getNumArguments() - state_vars.size() + i;
     TF::VarHandleOp state_var = state_vars[i];
     auto device_attr = state_var.getAttr(kDeviceAttr);
@@ -511,17 +511,19 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   // generate unique names.
   if (!output_names.empty()) {
     const int num_data_results = graph_op.getNumResults();
-    TF_RET_CHECK(output_names.size() == num_data_results)
+    const int64 output_names_size = output_names.size();
+    TF_RET_CHECK(output_names_size == num_data_results)
         << "output names (" << output_names.size()
         << ") != terminator operands (" << num_data_results << ")";
     llvm::DenseMap<Operation*, llvm::StringRef> output_op_to_name;
     llvm::StringMap<Operation*> name_to_op;
     for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) {
       // Skip control rets.
-      if (it.index() >= num_data_results) break;
+      const int64 index = it.index();
+      if (index >= num_data_results) break;
       // TODO(jpienaar): If there is a result index specified, ensure only one
       // and that it matches the result index of the op.
-      std::string orig_name(output_names[it.index()]);
+      std::string orig_name(output_names[index]);
       auto tensor_id = ParseTensorName(orig_name);
       auto name = LegalizeNodeName(
           llvm::StringRef(tensor_id.node().data(), tensor_id.node().size()));
@@ -2364,7 +2364,8 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph(
                                        "' is missing attribute 'index'");
 
       auto index = attr->i();
-      if (nodes->size() < index + 1) nodes->resize(index + 1);
+      const int num_nodes = nodes->size();
+      if (num_nodes < index + 1) nodes->resize(index + 1);
 
       if ((*nodes)[index].node != nullptr)
         return errors::InvalidArgument(node->type_string(), " node '",
@@ -3085,7 +3086,8 @@ Status CreateSavedModelIR(
       TF_ASSIGN_OR_RETURN(auto input_index_paths,
                           input_linearizer.GetLeafIndexPaths(
                               error_context + "in input signature: "));
-      if (bound_input_base != input_index_paths.size()) {
+      const int input_index_paths_size = input_index_paths.size();
+      if (bound_input_base != input_index_paths_size) {
         return errors::InvalidArgument(
             error_context,
             "Argument mismatch between concrete function input signature "
@@ -94,7 +94,7 @@ Status ParseInputArrayInfo(const std::vector<string>& node_names,
   }
 
   // StringMap doesn't support reserve else reserve input map size here.
-  for (int i = 0; i < node_names.size(); i++) {
+  for (int i = 0, end = node_names.size(); i < end; i++) {
     auto& name = node_names[i];
     if (name.empty()) continue;
 
@@ -82,7 +82,7 @@ Status ConvertLocation(mlir::Location inst_loc,
     if (locations.size() <= 1)
       return errors::InvalidArgument("expected experimental debuf info.");
     // skip the first one, which is the name of the node_def.
-    for (int i = 0; i < locations.size() - 1; ++i) {
+    for (int i = 0, end = locations.size() - 1; i < end; ++i) {
       TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info));
     }
   }
@@ -518,7 +518,7 @@ Status SetSizeAttribute(absl::string_view name, size_t size,
     // This should be extremely rare as it means we are adding the same
     // attribute multiple times/have some redundancy in representing this
     // attribute.
-    int64 actual_size = result.first->second.i();
+    size_t actual_size = result.first->second.i();
     // Just check via string output as we shouldn't get here and if we do they
     // should be trivially the same, else fail.
    if (actual_size != size)
@@ -149,7 +149,8 @@ Status GetTPUDevices(
            std::next(system_devices.begin()), system_devices.end())) {
     auto host_tpu_devices = lookup(device_spec);
     // Check number of TPU devices per host all match.
-    if (num_tpus_per_host != host_tpu_devices.size())
+    const int64 host_tpu_devices_size = host_tpu_devices.size();
+    if (num_tpus_per_host != host_tpu_devices_size)
       return errors::InvalidArgument(
           "expected the number of TPU devices per host to be ",
           num_tpus_per_host, ", got ", host_tpu_devices.size());
@@ -354,7 +355,8 @@ GetGeneralTPUExecutionDeviceAssignment(
 
   const int expected_device_assignment_size =
       num_replicas * num_cores_per_replica * kTPUTopologyRank;
-  if (device_assignment_attr.size() != expected_device_assignment_size)
+  const int device_assignment_attr_size = device_assignment_attr.size();
+  if (device_assignment_attr_size != expected_device_assignment_size)
     return errors::InvalidArgument(
         "length of '", kDeviceAssignmentAttr,
         "' must be 'num_replicas' * 'num_cores_per_replica' * ",
@@ -242,7 +242,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices(
           cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs);
       if (mlir::failed(result)) return mlir::failure();
 
-      if (tiled_inputs.size() != num_cores_per_replica)
+      const int64 tiled_inputs_size = tiled_inputs.size();
+      if (tiled_inputs_size != num_cores_per_replica)
         cluster_func.emitError(llvm::formatv(
             "incorrect {0}-th tiled input sharding received. "
             "Product of tile sharding splits({1}) must be equal to "
@@ -376,7 +377,8 @@ mlir::LogicalResult HandleTileShardedOutputs(
 
     llvm::SmallVector<mlir::Value, 4> new_outputs;
     new_outputs.reserve(num_splits);
-    for (int i = 0; i < outputs_to_merge.size(); i = i + num_splits) {
+    for (int i = 0, end = outputs_to_merge.size(); i < end;
+         i = i + num_splits) {
      mlir::TF::ConcatOp concat_op;
      auto result =
          CreateConcatOp(concat_dimension, location,
@@ -169,8 +169,8 @@ static std::vector<std::pair<int64, int64>> Convert_source_target_pairs(
 
 static std::vector<xla::ReplicaGroup> Convert_replica_groups(
     mlir::DenseIntElementsAttr groups) {
-  int64_t num_groups = groups.getType().getDimSize(0);
-  int64_t group_size = groups.getType().getDimSize(1);
+  uint64_t num_groups = groups.getType().getDimSize(0);
+  uint64_t group_size = groups.getType().getDimSize(1);
 
   std::vector<xla::ReplicaGroup> result;
   result.reserve(num_groups);
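This is the one hunk that flips the bounds unsigned rather than signing the index, presumably because both values are consumed only in unsigned contexts: `reserve()` takes a `size_t`, and the loop counters can simply be `uint64_t`. A rough standalone sketch of that shape (`BuildFlatGroups` and its body are illustrative, not the real Convert_replica_groups):

#include <cstdint>
#include <vector>

std::vector<int64_t> BuildFlatGroups(uint64_t num_groups, uint64_t group_size) {
  std::vector<int64_t> flat;
  flat.reserve(num_groups * group_size);  // size_t parameter: no mismatch
  for (uint64_t g = 0; g < num_groups; ++g)    // unsigned vs. unsigned
    for (uint64_t e = 0; e < group_size; ++e)
      flat.push_back(static_cast<int64_t>(g * group_size + e));
  return flat;
}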
@@ -434,14 +434,14 @@ static void ExtractShardingsFromFunction(
     llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
   arg_shardings->resize(function.getNumArguments(),
                         absl::optional<xla::OpSharding>());
-  for (int i = 0; i < function.getNumArguments(); ++i)
+  for (int i = 0, end = function.getNumArguments(); i < end; ++i)
     if (auto sharding =
             function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
       (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
 
   ret_shardings->resize(function.getNumResults(),
                         absl::optional<xla::OpSharding>());
-  for (int i = 0; i < function.getNumResults(); ++i)
+  for (int i = 0, end = function.getNumResults(); i < end; ++i)
     if (auto sharding =
             function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
       (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
@@ -757,7 +757,7 @@ LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
   auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
   auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
   auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
-  for (xla::int64 i = 0; i < edge_padding_low.size(); ++i) {
+  for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) {
     auto* dims = padding_config.add_dimensions();
     dims->set_edge_padding_low(edge_padding_low[i]);
     dims->set_edge_padding_high(edge_padding_high[i]);
@@ -365,7 +365,7 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
                                     ArrayRef<int64_t> minor_starts,
                                     OpBuilder *builder) {
   llvm::SmallVector<Value, 4> dus_starts(minor_starts.size());
-  for (int64_t i = 0; i < minor_starts.size(); ++i) {
+  for (uint64_t i = 0; i < minor_starts.size(); ++i) {
     dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc,
                                          minor_starts[i], builder);
   }
@@ -808,7 +808,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D(
   values.reserve(shaped_type.getNumElements() / shape[1]);
 
   for (auto it : llvm::enumerate(int_attr.getIntValues())) {
-    if (it.index() % shape[1] == column) {
+    if (static_cast<int>(it.index() % shape[1]) == column) {
       values.push_back(it.value().getSExtValue());
     }
   }
@@ -2945,7 +2945,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
     SmallVector<Value, 4> slices;
     slices.reserve(op.getNumResults());
 
-    for (int i = 0; i < op.getNumResults(); ++i) {
+    for (int i = 0, end = op.getNumResults(); i < end; ++i) {
       end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
       slices.push_back(rewriter.create<mhlo::SliceOp>(
          op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
@@ -3120,7 +3120,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
     // verifier.
     int64_t slicing_dim_size =
         op.begin().getType().cast<RankedTensorType>().getShape()[0];
-    auto input_rank = input_shape.size();
+    const int input_rank = input_shape.size();
     for (int d = slicing_dim_size; d < input_rank; ++d) {
       // We only support slicing major dimensions, so minor dimensions after
       // slicing dimensions are all sliced with their full sizes.
@@ -3161,7 +3161,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
     }
 
     // For non-slice dims, get the full slice of that dimension.
-    for (int d = slicing_dim_size; d < input_shape.size(); ++d) {
+    for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) {
       slice_sizes.push_back(input_shape[d]);
       slice_begin_indices.push_back(zero);
     }
@@ -3857,7 +3857,8 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
         multiples.getType().getRank() != 1)
       return failure();
 
-    if (multiples.getNumElements() != input_shape.size()) return failure();
+    const int64_t input_shape_size = input_shape.size();
+    if (multiples.getNumElements() != input_shape_size) return failure();
 
     SmallVector<int64_t, 8> broadcasted_shape;
     SmallVector<int64_t, 4> broadcast_dimensions;
@@ -4644,7 +4645,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
     SmallVector<Value, 4> results;
     results.reserve(op.getNumResults());
 
-    for (int i = 0; i < op.getNumResults(); ++i) {
+    for (int i = 0, end = op.getNumResults(); i < end; ++i) {
       begin_indices[axis] = i;
       end_indices[axis] = i + 1;
 