Use OpState::operator->() to access member functions of Operation, so that the corresponding forwarding methods can be removed from OpState.

PiperOrigin-RevId: 346865218
Change-Id: Iafe229bfc8577713031cb352903c9fa61cf077f2
Christian Sigg, 2020-12-10 14:25:52 -08:00 (committed by TensorFlower Gardener)
parent 6bdd30ec15
commit 10c7c897c8
26 changed files with 61 additions and 60 deletions
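
Every hunk below applies the same mechanical substitution: call sites that reached Operation members through forwarding methods on the op wrapper (dot syntax) now go through OpState::operator->(), which returns the underlying mlir::Operation*, so those forwarders can later be deleted from OpState. A minimal illustrative sketch of the call-site shape before and after, with TF::ConstOp standing in for any generated op class (the snippet is not part of the patch):

// Illustrative sketch only. Assumes the usual MLIR/TF includes, e.g.
//   #include "mlir/IR/Operation.h"
//   #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
void Example(TF::ConstOp op) {
  // Before: dot syntax through a forwarding method defined on OpState.
  //   StringAttr device = op.getAttrOfType<StringAttr>("device");

  // After: OpState::operator->() yields the underlying mlir::Operation*,
  // and the same member functions are called on it directly.
  StringAttr device = op->getAttrOfType<StringAttr>("device");
  Operation* parent = op->getParentOp();
  ModuleOp module = op->getParentOfType<ModuleOp>();
  (void)device; (void)parent; (void)module;
}

Note that accessors generated for a specific op (for example cluster_func.funcAttr() or island_op.GetBody() below) are unaffected: they live on the op class itself and keep the dot syntax. Only members that mlir::Operation provides move to the arrow.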


@@ -58,10 +58,10 @@ void AnnotateParameterReplication::runOnOperation() {
ModuleOp m = getOperation();
OpBuilder builder(m.getContext());
m.walk([&](tf_device::ClusterFuncOp cluster_func) {
- auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
+ auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
if (!replicate) return;
auto mirrored_variable_indices_attr =
- replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
+ replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
llvm::SmallDenseSet<int64_t, 8> mirrored_replicate_args;
if (mirrored_variable_indices_attr) {
for (const auto& mirrored_index : mirrored_variable_indices_attr) {


@@ -291,7 +291,7 @@ class ClusterTFOpsByHostPass
void runOnFunction() override {
MLIRContext *context = &getContext();
FuncOp func_op = getOperation();
- ModuleOp module_op = func_op.getParentOfType<mlir::ModuleOp>();
+ ModuleOp module_op = func_op->getParentOfType<mlir::ModuleOp>();
llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
GetFunctionMetadatas(func_op);


@@ -42,7 +42,7 @@ void ConstantOpDeviceAssignmentPass::runOnOperation() {
module.walk([&](TF::ConstOp op) {
// Keep the ConstOp if the op already have the device attribute.
- if (StringAttr device_attr = op.getAttrOfType<StringAttr>(kDeviceAttr)) {
+ if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
return WalkResult::advance();
}
OpBuilder builder(op);


@@ -66,7 +66,7 @@ struct ExecutorIslandCoarsening
// that is closest to the island in the graph. If no candidate can be found or
// the op found is not an island, an empty optional is returned.
llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
- Operation* graph_op = island.getParentOp();
+ Operation* graph_op = island->getParentOp();
Operation* candidate = nullptr;
// Check island control operands.
@@ -95,7 +95,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
// an op, that is closest to the island in the graph. If no candidate can be
// found or the op found is not an island, an empty optional is returned.
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
- Operation* graph_op = island.getParentOp();
+ Operation* graph_op = island->getParentOp();
Operation* candidate = nullptr;
// Check island control results.


@@ -78,7 +78,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
// in a new module to run the V1 bridge there.
SmallVector<IslandOp, 8> islands_to_outline;
getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
- auto island_op = cast<IslandOp>(replicate_op.getParentOp());
+ auto island_op = cast<IslandOp>(replicate_op->getParentOp());
if (!island_op || island_op.WrapsSingleOp()) return;
islands_to_outline.push_back(island_op);
});


@@ -40,7 +40,7 @@ namespace {
// "tf.entry_function" attribute defined.
bool CanPruneGraph(FuncOp func) {
return func.getName() != "main" ||
- func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
+ func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
}
// Visits an op's operand if it is an output of an Operation in the same


@@ -124,7 +124,7 @@ void LayoutAssignmentPass::runOnFunction() {
// Get runtime devices information from the closest parent module.
RuntimeDevices devices;
- if (failed(::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(),
+ if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
&devices)))
return signalPassFailure();


@@ -264,7 +264,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
// Only if `allow_soft_placement` attribute is true should we mark ops
// for outside compilation.
auto soft_placement_attr =
- cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
+ cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
return WalkResult::advance();
}
@@ -281,7 +281,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
// Only if `allow_soft_placement` attribute is true should we unmark ops
// for outside compilation.
auto soft_placement_attr =
- cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
+ cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
return;
}


@@ -165,7 +165,7 @@ void CreateIslandsFromParallelExecute(
unused_execute_controls.push_back(execute.control());
if (!unused_execute_controls.empty()) {
- auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
+ auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
tf_executor::FetchOp fetch = graph_op.GetFetch();
auto fetches = llvm::to_vector<8>(fetch.getOperands());
fetches.append(unused_execute_controls.begin(),


@@ -138,7 +138,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
ShapedType shaped_type =
variable_v2_op.getResult().getType().cast<ShapedType>();
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
- StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
+ StringAttr device_attr =
+ variable_v2_op->getAttrOfType<StringAttr>("device");
if (!device_attr) device_attr = builder.getStringAttr("");
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
if (variable_name.empty()) {


@@ -210,8 +210,8 @@ using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
if (first.getNumOperands() != second.getNumOperands()) return false;
- Region& first_region = *first.getParentRegion();
- Region& second_region = *second.getParentRegion();
+ Region& first_region = *first->getParentRegion();
+ Region& second_region = *second->getParentRegion();
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
// Get the defining Op, skipping over casts.


@@ -316,7 +316,7 @@ void ReplicateToIslandPass::runOnFunction() {
});
for (tf_executor::IslandOp island_op : replicate_op_islands) {
- auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
+ auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
auto replicate_op =
cast<tf_device::ReplicateOp>(island_op.GetBody().front());
if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,


@@ -1106,7 +1106,7 @@ LogicalResult HandlePartitionedCallOpCallee(
// Clone the callee before making changes.
SmallString<64> name_base = callee.getName();
- auto module = callee.getParentOfType<ModuleOp>();
+ auto module = callee->getParentOfType<ModuleOp>();
name_base += "_resource_lifted";
auto name = name_base;
callee = callee.clone();
@@ -1376,7 +1376,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
lifted_partitioned_call_callees;
if (failed(HoistForControlFlow(
- &function.front(), cast<ModuleOp>(function.getParentOp()),
+ &function.front(), cast<ModuleOp>(function->getParentOp()),
/*vars_initialized=*/false, &lifted_partitioned_call_callees)))
return failure();


@@ -117,7 +117,7 @@ void EliminateUnusedResults(
// multiple uses or unknown uses (for external functions). The cloned function
// will be marked as private.
FuncOp CloneFunctionIfNeeded(FuncOp func) {
- ModuleOp module = func.getParentOfType<ModuleOp>();
+ ModuleOp module = func->getParentOfType<ModuleOp>();
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
return func;


@@ -247,7 +247,7 @@ bool CanInferTensorListElementType(Value tensorlist,
continue;
}
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
- Operation* parent = yield.getParentOp();
+ Operation* parent = yield->getParentOp();
if (!CanInferTensorListElementType(
parent->getResult(use.getOperandNumber()), initial_element_shape,
potential_element_type))
@@ -619,7 +619,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
ArrayRef<FuncOp> ShapeInference::GetCallers(FuncOp fn) {
auto pair = callers_of_func_.try_emplace(fn);
if (pair.second) {
- ModuleOp module = fn.getParentOfType<ModuleOp>();
+ ModuleOp module = fn->getParentOfType<ModuleOp>();
auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
if (uses) {
pair.first->second.reserve(pair.first->second.size());


@@ -60,7 +60,7 @@ class TensorDeviceCopyConversionPass
arg_device = attr;
}
- StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
+ StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
if (!op_device) op_device = empty_string;
// Skip the folding logic if the argument's device is different from the
// operation's device.


@@ -116,7 +116,7 @@ void TPUColocateCompositeResourceOps::runOnFunction() {
OpBuilder builder(&getContext());
for (auto execute_launch : execute_launches) {
- auto replicate = execute_launch.getParentOfType<tf_device::ReplicateOp>();
+ auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
if (!replicate) continue;
ColocateCompositeResourceOpsInReplicate(replicate, &builder);


@@ -109,7 +109,7 @@ bool IsSupportedInputOp(
};
// Check all generator aliases (ops or function argument) are on CPU.
- FuncOp func = iterator_op.getParentOfType<FuncOp>();
+ FuncOp func = iterator_op->getParentOfType<FuncOp>();
return llvm::all_of(aliases, [&](Value alias) {
// Ignore non-generator aliases.
if (!is_generator(alias)) return true;
@@ -230,7 +230,7 @@ void HandleCompileAndExecutes(
bool metadata_updated = false;
auto maybe_replicate =
- execute_launches.front().getParentOfType<tf_device::ReplicateOp>();
+ execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();
for (auto execute_and_input_mapping :
llvm::zip(execute_launches, input_mappings)) {
@@ -284,7 +284,7 @@ void TPUDynamicLayoutPass::runOnFunction(
func.walk([&](TF::_TPUCompileMlirOp compile) {
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
auto compile_launch =
- llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
+ llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
@@ -295,7 +295,7 @@ void TPUDynamicLayoutPass::runOnFunction(
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
if (!execute) return;
auto execute_launch =
- llvm::dyn_cast<tf_device::LaunchOp>(execute.getParentOp());
+ llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
execute_launches.push_back(execute_launch);
}


@@ -180,7 +180,7 @@ void AnnotateFunctionArgumentsWithPaddings(
LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
SymbolTable* symbol_table) {
- auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
+ auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
// LaunchFunc is not replicated, there will be no padding.
if (!replicate) return success();
@@ -188,7 +188,7 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
if (!func) return success();
auto replicated_input_indices_attr =
- replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
+ replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
if (!replicated_input_indices_attr) return success();
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =


@@ -131,7 +131,7 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
const TF::SideEffectAnalysis& side_effect_analysis,
tf_device::ClusterOp cluster) {
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
- cluster.getParentOfType<FuncOp>());
+ cluster->getParentOfType<FuncOp>());
Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
@@ -227,7 +227,7 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
llvm::SmallVectorImpl<Value>* cluster_results) {
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
- cluster.getParentOfType<FuncOp>());
+ cluster->getParentOfType<FuncOp>());
Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
Operation* terminator = cluster.GetBody().getTerminator();


@@ -755,7 +755,7 @@ void MoveOutsideCompiledOps(
// If there is no replication/data parallelism, it is assumed the device
// ordinal is always 0 (e.g. /device:TPU:0). In that case, a constant 0
// attribute can be used instead for _XlaSendFromHost/_XlaRecvAtHost ops.
- if (tpu_cluster.getParentOfType<tf_device::ReplicateOp>()) {
+ if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
auto device_ordinal_op =
builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
host_launch_op.getLoc(),


@@ -127,7 +127,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
VariableAccessesForTPUExecute infos;
Attribute device_attr = execute_launch.deviceAttr();
if (check_device && !device_attr) return infos;
- auto func = execute_launch.getParentOfType<mlir::FuncOp>();
+ auto func = execute_launch->getParentOfType<mlir::FuncOp>();
// Track the first read op found, which is used later to check if there are
// assign ops between it and the TPUExecute op. We will exclude reads before
@@ -137,7 +137,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
Operation* first_read = nullptr;
Operation& execute = execute_launch.GetBody().front();
auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
- execute_launch.getParentOp());
+ execute_launch->getParentOp());
Operation* execute_parent =
parallel_execute ? parallel_execute.getOperation() : execute_launch;
// Find inputs that are variable reads.
@@ -148,7 +148,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
operand.value().get().getDefiningOp());
if (!read_op) continue;
if (check_same_region &&
- read_op.getParentRegion() != execute_parent->getParentRegion())
+ read_op->getParentRegion() != execute_parent->getParentRegion())
continue;
auto resource = read_op.resource();
@@ -240,7 +240,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
auto execute_outputs =
parallel_execute
? parallel_execute.GetRegionOutputs(
- execute_launch.getParentRegion()->getRegionNumber())
+ execute_launch->getParentRegion()->getRegionNumber())
: execute_launch.getResults();
for (auto execute_output : llvm::enumerate(execute_outputs)) {
// TODO(lyandy): Handle updates to resource writes by remapping to parent
@@ -340,7 +340,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
llvm::SmallVector<Type, 8> output_types;
const int parallel_execute_num_results = parallel_execute_op->getNumResults();
output_types.reserve(parallel_execute_num_results);
- Region* execute_region = merged_execute_launch.getParentRegion();
+ Region* execute_region = merged_execute_launch->getParentRegion();
const int region_index = execute_region->getRegionNumber();
const int num_results_before_region =
AppendTypes(&output_types, parallel_execute, 0, region_index);
@@ -547,7 +547,7 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
merged_execute_launch.GetBody().getTerminator());
if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
- execute_launch.getParentOp()))
+ execute_launch->getParentOp()))
ReplaceParallelExecute(parallel_execute, execute_launch,
merged_execute_launch, infos, builder);
else
@@ -591,11 +591,11 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() {
for (auto execute_launch : execute_launches) {
OpBuilder builder(&getContext());
const bool parent_is_replicate =
- llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp()) ||
+ llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
(llvm::isa<tf_device::ParallelExecuteOp>(
- execute_launch.getParentOp()) &&
+ execute_launch->getParentOp()) &&
llvm::isa<tf_device::ReplicateOp>(
- execute_launch.getParentOp()->getParentOp()));
+ execute_launch->getParentOp()->getParentOp()));
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
// to be on the same device as the TPUExecute op. Skip device checking in


@@ -106,14 +106,14 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
std::string* serialized_func_module) {
- ModuleOp module = entry_func.getParentOfType<ModuleOp>();
+ ModuleOp module = entry_func->getParentOfType<ModuleOp>();
SymbolTable entry_module_table(module);
llvm::SmallVector<FuncOp, 4> referenced({entry_func});
// Create a new module to hold func and all referenced functions.
OwningModuleRef module_for_func =
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
- auto parent_module = entry_func.getParentOfType<ModuleOp>();
+ auto parent_module = entry_func->getParentOfType<ModuleOp>();
auto versions_attr = parent_module.getAttr(kVersionsAttr);
if (!versions_attr)
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
@@ -165,7 +165,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
tf_device::ClusterFuncOp op,
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
auto step_marker_location =
- op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
+ op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
if (!step_marker_location)
return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
@@ -190,7 +190,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
LogicalResult SetMetadataProtoPaddingMap(
tf_device::ClusterFuncOp op,
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
- auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
+ auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
if (!padding_map)
return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
@@ -234,7 +234,7 @@ LogicalResult SetMetadataProtoArgs(
tf_device::ClusterFuncOp op,
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
auto input_shardings =
- op.getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
+ op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
if (!input_shardings)
return op.emitOpError(
CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));
@@ -289,7 +289,7 @@ LogicalResult SetMetadataProtoRetvals(
tf_device::ClusterFuncOp op,
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
auto output_shardings =
- op.getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
+ op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
if (!output_shardings)
return op.emitOpError(
CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));
@@ -329,7 +329,7 @@ LogicalResult SetMetadataProtoFromClusterFuncOp(
if (xla_device_assignment.hasValue())
*metadata->mutable_device_assignment() =
std::move(xla_device_assignment.getValue());
- auto use_spmd_attr = op.getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
+ auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
if (!use_spmd_attr)
return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());
@@ -400,7 +400,7 @@ Operation* BuildCompileOp(
}
FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
- FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
+ FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
func_attr.getValue());
std::string txt_module;
@@ -637,16 +637,16 @@ LogicalResult Rewrite(
OpBuilder* builder) {
// Skip non-tpu device cluster_func.
auto replicate_attr =
- cluster_func.getAttrOfType<StringAttr>("_tpu_replicate");
+ cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
if (!replicate_attr) return success();
// Collect `num_replicas` and `num_cores_per_replica` attributes.
int num_replicas = 1;
tf_device::ReplicateOp replicate =
- cluster_func.getParentOfType<tf_device::ReplicateOp>();
+ cluster_func->getParentOfType<tf_device::ReplicateOp>();
if (replicate) num_replicas = replicate.n();
- auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
+ auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
tensorflow::kNumCoresPerReplicaAttr);
if (!num_cores_per_replica_attr)
return cluster_func.emitOpError(
@@ -655,12 +655,12 @@ LogicalResult Rewrite(
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
auto topology_attr =
- cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
+ cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
if (!topology_attr)
return cluster_func.emitOpError(
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
- auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
+ auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
tensorflow::kDeviceAssignmentAttr);
if (!device_assignment_attr)
return cluster_func.emitOpError(
@@ -692,11 +692,11 @@ LogicalResult Rewrite(
// Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
// parallel_execute region if it exists.
- if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
+ if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
// Currently, outside compilation and model parallelism are not supported
// together.
assert(num_cores_per_replica == 1);
- builder->setInsertionPoint(cluster_func.getParentOp());
+ builder->setInsertionPoint(cluster_func->getParentOp());
}
Operation* compile_op = BuildCompileOp(
@@ -711,7 +711,7 @@ LogicalResult Rewrite(
// and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
// structured lowering.
if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
- cluster_func.getParentOp())) {
+ cluster_func->getParentOp())) {
parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
key_op.replaceAllUsesWith(compile_op->getResult(1));
key_op.erase();


@@ -211,7 +211,7 @@ void IdentifyXlaShardingForComputationOutputs(
void IdentifyXlaShardingForTPUComputation(
Builder* builder, tf_device::ClusterFuncOp cluster_func) {
// Look up function definition from module.
- FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
+ FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
cluster_func.func());
// By default inputs/outputs have maximal sharding and are assigned to logical


@@ -483,7 +483,7 @@ bool HandleHostReplicatedInputs(int64_t index,
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
unsigned arg_num) {
auto maybe_replicate =
- llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func.getParentOp());
+ llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());
llvm::SmallVector<int64_t, 8> transform_input_indices;
for (auto input : llvm::enumerate(cluster_func.operands())) {


@@ -151,7 +151,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
auto mirrored_variable_indices_attr =
- replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
+ replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
if (!mirrored_variable_indices_attr) return mapping;
// Finds the mapping from a replicate argument to an execute operand.