Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.
PiperOrigin-RevId: 346865218 Change-Id: Iafe229bfc8577713031cb352903c9fa61cf077f2
This commit is contained in:
parent
6bdd30ec15
commit
10c7c897c8
@ -58,10 +58,10 @@ void AnnotateParameterReplication::runOnOperation() {
|
||||
ModuleOp m = getOperation();
|
||||
OpBuilder builder(m.getContext());
|
||||
m.walk([&](tf_device::ClusterFuncOp cluster_func) {
|
||||
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
||||
auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||
if (!replicate) return;
|
||||
auto mirrored_variable_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
llvm::SmallDenseSet<int64_t, 8> mirrored_replicate_args;
|
||||
if (mirrored_variable_indices_attr) {
|
||||
for (const auto& mirrored_index : mirrored_variable_indices_attr) {
|
||||
|
||||
@ -291,7 +291,7 @@ class ClusterTFOpsByHostPass
|
||||
void runOnFunction() override {
|
||||
MLIRContext *context = &getContext();
|
||||
FuncOp func_op = getOperation();
|
||||
ModuleOp module_op = func_op.getParentOfType<mlir::ModuleOp>();
|
||||
ModuleOp module_op = func_op->getParentOfType<mlir::ModuleOp>();
|
||||
|
||||
llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
|
||||
GetFunctionMetadatas(func_op);
|
||||
|
||||
@ -42,7 +42,7 @@ void ConstantOpDeviceAssignmentPass::runOnOperation() {
|
||||
|
||||
module.walk([&](TF::ConstOp op) {
|
||||
// Keep the ConstOp if the op already has the device attribute.
|
||||
if (StringAttr device_attr = op.getAttrOfType<StringAttr>(kDeviceAttr)) {
|
||||
if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
|
||||
return WalkResult::advance();
|
||||
}
|
||||
OpBuilder builder(op);
|
||||
|
||||
@ -66,7 +66,7 @@ struct ExecutorIslandCoarsening
|
||||
// that is closest to the island in the graph. If no candidate can be found or
|
||||
// the op found is not an island, an empty optional is returned.
|
||||
llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
||||
Operation* graph_op = island.getParentOp();
|
||||
Operation* graph_op = island->getParentOp();
|
||||
Operation* candidate = nullptr;
|
||||
|
||||
// Check island control operands.
|
||||
@ -95,7 +95,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
||||
// an op, that is closest to the island in the graph. If no candidate can be
|
||||
// found or the op found is not an island, an empty optional is returned.
|
||||
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
||||
Operation* graph_op = island.getParentOp();
|
||||
Operation* graph_op = island->getParentOp();
|
||||
Operation* candidate = nullptr;
|
||||
|
||||
// Check island control results.
|
||||
|
||||
@ -78,7 +78,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
// in a new module to run the V1 bridge there.
|
||||
SmallVector<IslandOp, 8> islands_to_outline;
|
||||
getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
|
||||
auto island_op = cast<IslandOp>(replicate_op.getParentOp());
|
||||
auto island_op = cast<IslandOp>(replicate_op->getParentOp());
|
||||
if (!island_op || island_op.WrapsSingleOp()) return;
|
||||
islands_to_outline.push_back(island_op);
|
||||
});
|
||||
|
||||
@ -40,7 +40,7 @@ namespace {
|
||||
// "tf.entry_function" attribute defined.
|
||||
bool CanPruneGraph(FuncOp func) {
|
||||
return func.getName() != "main" ||
|
||||
func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
|
||||
func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
|
||||
}
|
||||
|
||||
// Visits an op's operand if it is an output of an Operation in the same
|
||||
|
||||
@ -124,7 +124,7 @@ void LayoutAssignmentPass::runOnFunction() {
|
||||
|
||||
// Get runtime devices information from the closest parent module.
|
||||
RuntimeDevices devices;
|
||||
if (failed(::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(),
|
||||
if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
|
||||
&devices)))
|
||||
return signalPassFailure();
|
||||
|
||||
|
||||
@ -264,7 +264,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
// Only if `allow_soft_placement` attribute is true should we mark ops
|
||||
// for outside compilation.
|
||||
auto soft_placement_attr =
|
||||
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||
cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
||||
return WalkResult::advance();
|
||||
}
|
||||
@ -281,7 +281,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
||||
// Only if `allow_soft_placement` attribute is true should we unmark ops
|
||||
// for outside compilation.
|
||||
auto soft_placement_attr =
|
||||
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||
cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@ void CreateIslandsFromParallelExecute(
|
||||
unused_execute_controls.push_back(execute.control());
|
||||
|
||||
if (!unused_execute_controls.empty()) {
|
||||
auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
|
||||
auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
|
||||
tf_executor::FetchOp fetch = graph_op.GetFetch();
|
||||
auto fetches = llvm::to_vector<8>(fetch.getOperands());
|
||||
fetches.append(unused_execute_controls.begin(),
|
||||
|
||||
@ -138,7 +138,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
|
||||
ShapedType shaped_type =
|
||||
variable_v2_op.getResult().getType().cast<ShapedType>();
|
||||
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
|
||||
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
|
||||
StringAttr device_attr =
|
||||
variable_v2_op->getAttrOfType<StringAttr>("device");
|
||||
if (!device_attr) device_attr = builder.getStringAttr("");
|
||||
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
|
||||
if (variable_name.empty()) {
|
||||
|
||||
@ -210,8 +210,8 @@ using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
|
||||
bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
|
||||
if (first.getNumOperands() != second.getNumOperands()) return false;
|
||||
|
||||
Region& first_region = *first.getParentRegion();
|
||||
Region& second_region = *second.getParentRegion();
|
||||
Region& first_region = *first->getParentRegion();
|
||||
Region& second_region = *second->getParentRegion();
|
||||
|
||||
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
|
||||
// Get the defining Op, skipping over casts.
|
||||
|
||||
@ -316,7 +316,7 @@ void ReplicateToIslandPass::runOnFunction() {
|
||||
});
|
||||
|
||||
for (tf_executor::IslandOp island_op : replicate_op_islands) {
|
||||
auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
|
||||
auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
|
||||
auto replicate_op =
|
||||
cast<tf_device::ReplicateOp>(island_op.GetBody().front());
|
||||
if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
|
||||
|
||||
@ -1106,7 +1106,7 @@ LogicalResult HandlePartitionedCallOpCallee(
|
||||
|
||||
// Clone the callee before making changes.
|
||||
SmallString<64> name_base = callee.getName();
|
||||
auto module = callee.getParentOfType<ModuleOp>();
|
||||
auto module = callee->getParentOfType<ModuleOp>();
|
||||
name_base += "_resource_lifted";
|
||||
auto name = name_base;
|
||||
callee = callee.clone();
|
||||
@ -1376,7 +1376,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
|
||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
||||
lifted_partitioned_call_callees;
|
||||
if (failed(HoistForControlFlow(
|
||||
&function.front(), cast<ModuleOp>(function.getParentOp()),
|
||||
&function.front(), cast<ModuleOp>(function->getParentOp()),
|
||||
/*vars_initialized=*/false, &lifted_partitioned_call_callees)))
|
||||
return failure();
|
||||
|
||||
|
||||
@ -117,7 +117,7 @@ void EliminateUnusedResults(
|
||||
// multiple uses or unknown uses (for external functions). The cloned function
|
||||
// will be marked as private.
|
||||
FuncOp CloneFunctionIfNeeded(FuncOp func) {
|
||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
||||
ModuleOp module = func->getParentOfType<ModuleOp>();
|
||||
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||
if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
|
||||
return func;
|
||||
|
||||
@ -247,7 +247,7 @@ bool CanInferTensorListElementType(Value tensorlist,
|
||||
continue;
|
||||
}
|
||||
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
|
||||
Operation* parent = yield.getParentOp();
|
||||
Operation* parent = yield->getParentOp();
|
||||
if (!CanInferTensorListElementType(
|
||||
parent->getResult(use.getOperandNumber()), initial_element_shape,
|
||||
potential_element_type))
|
||||
@ -619,7 +619,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
|
||||
ArrayRef<FuncOp> ShapeInference::GetCallers(FuncOp fn) {
|
||||
auto pair = callers_of_func_.try_emplace(fn);
|
||||
if (pair.second) {
|
||||
ModuleOp module = fn.getParentOfType<ModuleOp>();
|
||||
ModuleOp module = fn->getParentOfType<ModuleOp>();
|
||||
auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
|
||||
if (uses) {
|
||||
pair.first->second.reserve(pair.first->second.size());
|
||||
|
||||
@ -60,7 +60,7 @@ class TensorDeviceCopyConversionPass
|
||||
arg_device = attr;
|
||||
}
|
||||
|
||||
StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
|
||||
StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
|
||||
if (!op_device) op_device = empty_string;
|
||||
// Skip the folding logic if the argument's device is different from the
|
||||
// operation's device.
|
||||
|
||||
@ -116,7 +116,7 @@ void TPUColocateCompositeResourceOps::runOnFunction() {
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
for (auto execute_launch : execute_launches) {
|
||||
auto replicate = execute_launch.getParentOfType<tf_device::ReplicateOp>();
|
||||
auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
|
||||
if (!replicate) continue;
|
||||
|
||||
ColocateCompositeResourceOpsInReplicate(replicate, &builder);
|
||||
|
||||
@ -109,7 +109,7 @@ bool IsSupportedInputOp(
|
||||
};
|
||||
|
||||
// Check all generator aliases (ops or function argument) are on CPU.
|
||||
FuncOp func = iterator_op.getParentOfType<FuncOp>();
|
||||
FuncOp func = iterator_op->getParentOfType<FuncOp>();
|
||||
return llvm::all_of(aliases, [&](Value alias) {
|
||||
// Ignore non-generator aliases.
|
||||
if (!is_generator(alias)) return true;
|
||||
@ -230,7 +230,7 @@ void HandleCompileAndExecutes(
|
||||
|
||||
bool metadata_updated = false;
|
||||
auto maybe_replicate =
|
||||
execute_launches.front().getParentOfType<tf_device::ReplicateOp>();
|
||||
execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();
|
||||
|
||||
for (auto execute_and_input_mapping :
|
||||
llvm::zip(execute_launches, input_mappings)) {
|
||||
@ -284,7 +284,7 @@ void TPUDynamicLayoutPass::runOnFunction(
|
||||
func.walk([&](TF::_TPUCompileMlirOp compile) {
|
||||
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
|
||||
auto compile_launch =
|
||||
llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
|
||||
llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
|
||||
if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
|
||||
|
||||
llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
|
||||
@ -295,7 +295,7 @@ void TPUDynamicLayoutPass::runOnFunction(
|
||||
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
|
||||
if (!execute) return;
|
||||
auto execute_launch =
|
||||
llvm::dyn_cast<tf_device::LaunchOp>(execute.getParentOp());
|
||||
llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
|
||||
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
|
||||
execute_launches.push_back(execute_launch);
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ void AnnotateFunctionArgumentsWithPaddings(
|
||||
|
||||
LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
||||
SymbolTable* symbol_table) {
|
||||
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
||||
auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||
// If LaunchFunc is not replicated, there will be no padding.
|
||||
if (!replicate) return success();
|
||||
|
||||
@ -188,7 +188,7 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
||||
if (!func) return success();
|
||||
|
||||
auto replicated_input_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
|
||||
replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
|
||||
if (!replicated_input_indices_attr) return success();
|
||||
|
||||
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
|
||||
|
||||
@ -131,7 +131,7 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
const TF::SideEffectAnalysis& side_effect_analysis,
|
||||
tf_device::ClusterOp cluster) {
|
||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||
cluster.getParentOfType<FuncOp>());
|
||||
cluster->getParentOfType<FuncOp>());
|
||||
Region* cluster_region = &cluster.body();
|
||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||
|
||||
@ -227,7 +227,7 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
|
||||
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
|
||||
llvm::SmallVectorImpl<Value>* cluster_results) {
|
||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||
cluster.getParentOfType<FuncOp>());
|
||||
cluster->getParentOfType<FuncOp>());
|
||||
Region* cluster_region = &cluster.body();
|
||||
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
|
||||
Operation* terminator = cluster.GetBody().getTerminator();
|
||||
|
||||
@ -755,7 +755,7 @@ void MoveOutsideCompiledOps(
|
||||
// If there is no replication/data parallelism, it is assumed the device
|
||||
// ordinal is always 0 (e.g. /device:TPU:0). In that case, a constant 0
|
||||
// attribute can be used instead for _XlaSendFromHost/_XlaRecvAtHost ops.
|
||||
if (tpu_cluster.getParentOfType<tf_device::ReplicateOp>()) {
|
||||
if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
|
||||
auto device_ordinal_op =
|
||||
builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
|
||||
host_launch_op.getLoc(),
|
||||
|
||||
@ -127,7 +127,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
||||
VariableAccessesForTPUExecute infos;
|
||||
Attribute device_attr = execute_launch.deviceAttr();
|
||||
if (check_device && !device_attr) return infos;
|
||||
auto func = execute_launch.getParentOfType<mlir::FuncOp>();
|
||||
auto func = execute_launch->getParentOfType<mlir::FuncOp>();
|
||||
|
||||
// Track the first read op found, which is used later to check if there are
|
||||
// assign ops between it and the TPUExecute op. We will exclude reads before
|
||||
@ -137,7 +137,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
||||
Operation* first_read = nullptr;
|
||||
Operation& execute = execute_launch.GetBody().front();
|
||||
auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||
execute_launch.getParentOp());
|
||||
execute_launch->getParentOp());
|
||||
Operation* execute_parent =
|
||||
parallel_execute ? parallel_execute.getOperation() : execute_launch;
|
||||
// Find inputs that are variable reads.
|
||||
@ -148,7 +148,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
||||
operand.value().get().getDefiningOp());
|
||||
if (!read_op) continue;
|
||||
if (check_same_region &&
|
||||
read_op.getParentRegion() != execute_parent->getParentRegion())
|
||||
read_op->getParentRegion() != execute_parent->getParentRegion())
|
||||
continue;
|
||||
|
||||
auto resource = read_op.resource();
|
||||
@ -240,7 +240,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
||||
auto execute_outputs =
|
||||
parallel_execute
|
||||
? parallel_execute.GetRegionOutputs(
|
||||
execute_launch.getParentRegion()->getRegionNumber())
|
||||
execute_launch->getParentRegion()->getRegionNumber())
|
||||
: execute_launch.getResults();
|
||||
for (auto execute_output : llvm::enumerate(execute_outputs)) {
|
||||
// TODO(lyandy): Handle updates to resource writes by remapping to parent
|
||||
@ -340,7 +340,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
|
||||
llvm::SmallVector<Type, 8> output_types;
|
||||
const int parallel_execute_num_results = parallel_execute_op->getNumResults();
|
||||
output_types.reserve(parallel_execute_num_results);
|
||||
Region* execute_region = merged_execute_launch.getParentRegion();
|
||||
Region* execute_region = merged_execute_launch->getParentRegion();
|
||||
const int region_index = execute_region->getRegionNumber();
|
||||
const int num_results_before_region =
|
||||
AppendTypes(&output_types, parallel_execute, 0, region_index);
|
||||
@ -547,7 +547,7 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
|
||||
merged_execute_launch.GetBody().getTerminator());
|
||||
|
||||
if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||
execute_launch.getParentOp()))
|
||||
execute_launch->getParentOp()))
|
||||
ReplaceParallelExecute(parallel_execute, execute_launch,
|
||||
merged_execute_launch, infos, builder);
|
||||
else
|
||||
@ -591,11 +591,11 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() {
|
||||
for (auto execute_launch : execute_launches) {
|
||||
OpBuilder builder(&getContext());
|
||||
const bool parent_is_replicate =
|
||||
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp()) ||
|
||||
llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
|
||||
(llvm::isa<tf_device::ParallelExecuteOp>(
|
||||
execute_launch.getParentOp()) &&
|
||||
execute_launch->getParentOp()) &&
|
||||
llvm::isa<tf_device::ReplicateOp>(
|
||||
execute_launch.getParentOp()->getParentOp()));
|
||||
execute_launch->getParentOp()->getParentOp()));
|
||||
|
||||
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
|
||||
// to be on the same device as the TPUExecute op. Skip device checking in
|
||||
|
||||
@ -106,14 +106,14 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
|
||||
|
||||
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
||||
std::string* serialized_func_module) {
|
||||
ModuleOp module = entry_func.getParentOfType<ModuleOp>();
|
||||
ModuleOp module = entry_func->getParentOfType<ModuleOp>();
|
||||
SymbolTable entry_module_table(module);
|
||||
llvm::SmallVector<FuncOp, 4> referenced({entry_func});
|
||||
|
||||
// Create a new module to hold func and all referenced functions.
|
||||
OwningModuleRef module_for_func =
|
||||
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
||||
auto parent_module = entry_func.getParentOfType<ModuleOp>();
|
||||
auto parent_module = entry_func->getParentOfType<ModuleOp>();
|
||||
auto versions_attr = parent_module.getAttr(kVersionsAttr);
|
||||
if (!versions_attr)
|
||||
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
||||
@ -165,7 +165,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
|
||||
tf_device::ClusterFuncOp op,
|
||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||
auto step_marker_location =
|
||||
op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
|
||||
op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
|
||||
if (!step_marker_location)
|
||||
return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
|
||||
|
||||
@ -190,7 +190,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
|
||||
LogicalResult SetMetadataProtoPaddingMap(
|
||||
tf_device::ClusterFuncOp op,
|
||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||
auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
|
||||
auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
|
||||
if (!padding_map)
|
||||
return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
|
||||
|
||||
@ -234,7 +234,7 @@ LogicalResult SetMetadataProtoArgs(
|
||||
tf_device::ClusterFuncOp op,
|
||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||
auto input_shardings =
|
||||
op.getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
|
||||
op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
|
||||
if (!input_shardings)
|
||||
return op.emitOpError(
|
||||
CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));
|
||||
@ -289,7 +289,7 @@ LogicalResult SetMetadataProtoRetvals(
|
||||
tf_device::ClusterFuncOp op,
|
||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||
auto output_shardings =
|
||||
op.getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
|
||||
op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
|
||||
if (!output_shardings)
|
||||
return op.emitOpError(
|
||||
CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));
|
||||
@ -329,7 +329,7 @@ LogicalResult SetMetadataProtoFromClusterFuncOp(
|
||||
if (xla_device_assignment.hasValue())
|
||||
*metadata->mutable_device_assignment() =
|
||||
std::move(xla_device_assignment.getValue());
|
||||
auto use_spmd_attr = op.getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
|
||||
auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
|
||||
if (!use_spmd_attr)
|
||||
return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
|
||||
metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());
|
||||
@ -400,7 +400,7 @@ Operation* BuildCompileOp(
|
||||
}
|
||||
|
||||
FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
|
||||
FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
func_attr.getValue());
|
||||
|
||||
std::string txt_module;
|
||||
@ -637,16 +637,16 @@ LogicalResult Rewrite(
|
||||
OpBuilder* builder) {
|
||||
// Skip non-tpu device cluster_func.
|
||||
auto replicate_attr =
|
||||
cluster_func.getAttrOfType<StringAttr>("_tpu_replicate");
|
||||
cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
|
||||
if (!replicate_attr) return success();
|
||||
|
||||
// Collect `num_replicas` and `num_cores_per_replica` attributes.
|
||||
int num_replicas = 1;
|
||||
tf_device::ReplicateOp replicate =
|
||||
cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
||||
cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||
if (replicate) num_replicas = replicate.n();
|
||||
|
||||
auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
|
||||
auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
|
||||
tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return cluster_func.emitOpError(
|
||||
@ -655,12 +655,12 @@ LogicalResult Rewrite(
|
||||
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
||||
|
||||
auto topology_attr =
|
||||
cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return cluster_func.emitOpError(
|
||||
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
|
||||
|
||||
auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
|
||||
auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
|
||||
tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return cluster_func.emitOpError(
|
||||
@ -692,11 +692,11 @@ LogicalResult Rewrite(
|
||||
|
||||
// Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
|
||||
// parallel_execute region if it exists.
|
||||
if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
|
||||
if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
|
||||
// Currently, outside compilation and model parallelism are not supported
|
||||
// together.
|
||||
assert(num_cores_per_replica == 1);
|
||||
builder->setInsertionPoint(cluster_func.getParentOp());
|
||||
builder->setInsertionPoint(cluster_func->getParentOp());
|
||||
}
|
||||
|
||||
Operation* compile_op = BuildCompileOp(
|
||||
@ -711,7 +711,7 @@ LogicalResult Rewrite(
|
||||
// and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
|
||||
// structured lowering.
|
||||
if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||
cluster_func.getParentOp())) {
|
||||
cluster_func->getParentOp())) {
|
||||
parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
|
||||
key_op.replaceAllUsesWith(compile_op->getResult(1));
|
||||
key_op.erase();
|
||||
|
||||
@ -211,7 +211,7 @@ void IdentifyXlaShardingForComputationOutputs(
|
||||
void IdentifyXlaShardingForTPUComputation(
|
||||
Builder* builder, tf_device::ClusterFuncOp cluster_func) {
|
||||
// Look up function definition from module.
|
||||
FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
cluster_func.func());
|
||||
|
||||
// By default inputs/outputs have maximal sharding and are assigned to logical
|
||||
|
||||
@ -483,7 +483,7 @@ bool HandleHostReplicatedInputs(int64_t index,
|
||||
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
|
||||
unsigned arg_num) {
|
||||
auto maybe_replicate =
|
||||
llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func.getParentOp());
|
||||
llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());
|
||||
|
||||
llvm::SmallVector<int64_t, 8> transform_input_indices;
|
||||
for (auto input : llvm::enumerate(cluster_func.operands())) {
|
||||
|
||||
@ -151,7 +151,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
|
||||
auto mirrored_variable_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
if (!mirrored_variable_indices_attr) return mapping;
|
||||
|
||||
// Finds the mapping from a replicate argument to an execute operand.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user