Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.
PiperOrigin-RevId: 346865218 Change-Id: Iafe229bfc8577713031cb352903c9fa61cf077f2
This commit is contained in:
parent
6bdd30ec15
commit
10c7c897c8
@ -58,10 +58,10 @@ void AnnotateParameterReplication::runOnOperation() {
|
|||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
OpBuilder builder(m.getContext());
|
OpBuilder builder(m.getContext());
|
||||||
m.walk([&](tf_device::ClusterFuncOp cluster_func) {
|
m.walk([&](tf_device::ClusterFuncOp cluster_func) {
|
||||||
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||||
if (!replicate) return;
|
if (!replicate) return;
|
||||||
auto mirrored_variable_indices_attr =
|
auto mirrored_variable_indices_attr =
|
||||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||||
llvm::SmallDenseSet<int64_t, 8> mirrored_replicate_args;
|
llvm::SmallDenseSet<int64_t, 8> mirrored_replicate_args;
|
||||||
if (mirrored_variable_indices_attr) {
|
if (mirrored_variable_indices_attr) {
|
||||||
for (const auto& mirrored_index : mirrored_variable_indices_attr) {
|
for (const auto& mirrored_index : mirrored_variable_indices_attr) {
|
||||||
|
|||||||
@ -291,7 +291,7 @@ class ClusterTFOpsByHostPass
|
|||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
FuncOp func_op = getOperation();
|
FuncOp func_op = getOperation();
|
||||||
ModuleOp module_op = func_op.getParentOfType<mlir::ModuleOp>();
|
ModuleOp module_op = func_op->getParentOfType<mlir::ModuleOp>();
|
||||||
|
|
||||||
llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
|
llvm::Optional<llvm::StringMap<FunctionMetadata>> metadatas =
|
||||||
GetFunctionMetadatas(func_op);
|
GetFunctionMetadatas(func_op);
|
||||||
|
|||||||
@ -42,7 +42,7 @@ void ConstantOpDeviceAssignmentPass::runOnOperation() {
|
|||||||
|
|
||||||
module.walk([&](TF::ConstOp op) {
|
module.walk([&](TF::ConstOp op) {
|
||||||
// Keep the ConstOp if the op already have the device attribute.
|
// Keep the ConstOp if the op already have the device attribute.
|
||||||
if (StringAttr device_attr = op.getAttrOfType<StringAttr>(kDeviceAttr)) {
|
if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
OpBuilder builder(op);
|
OpBuilder builder(op);
|
||||||
|
|||||||
@ -66,7 +66,7 @@ struct ExecutorIslandCoarsening
|
|||||||
// that is closest to the island in the graph. If no candidate can be found or
|
// that is closest to the island in the graph. If no candidate can be found or
|
||||||
// the op found is not an island, an empty optional is returned.
|
// the op found is not an island, an empty optional is returned.
|
||||||
llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
||||||
Operation* graph_op = island.getParentOp();
|
Operation* graph_op = island->getParentOp();
|
||||||
Operation* candidate = nullptr;
|
Operation* candidate = nullptr;
|
||||||
|
|
||||||
// Check island control operands.
|
// Check island control operands.
|
||||||
@ -95,7 +95,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
|||||||
// an op, that is closest to the island in the graph. If no candidate can be
|
// an op, that is closest to the island in the graph. If no candidate can be
|
||||||
// found or the op found is not an island, an empty optional is returned.
|
// found or the op found is not an island, an empty optional is returned.
|
||||||
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
||||||
Operation* graph_op = island.getParentOp();
|
Operation* graph_op = island->getParentOp();
|
||||||
Operation* candidate = nullptr;
|
Operation* candidate = nullptr;
|
||||||
|
|
||||||
// Check island control results.
|
// Check island control results.
|
||||||
|
|||||||
@ -78,7 +78,7 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
|||||||
// in a new module to run the V1 bridge there.
|
// in a new module to run the V1 bridge there.
|
||||||
SmallVector<IslandOp, 8> islands_to_outline;
|
SmallVector<IslandOp, 8> islands_to_outline;
|
||||||
getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
|
getOperation().walk([&](TF::TPUReplicateMetadataOp replicate_op) {
|
||||||
auto island_op = cast<IslandOp>(replicate_op.getParentOp());
|
auto island_op = cast<IslandOp>(replicate_op->getParentOp());
|
||||||
if (!island_op || island_op.WrapsSingleOp()) return;
|
if (!island_op || island_op.WrapsSingleOp()) return;
|
||||||
islands_to_outline.push_back(island_op);
|
islands_to_outline.push_back(island_op);
|
||||||
});
|
});
|
||||||
|
|||||||
@ -40,7 +40,7 @@ namespace {
|
|||||||
// "tf.entry_function" attribute defined.
|
// "tf.entry_function" attribute defined.
|
||||||
bool CanPruneGraph(FuncOp func) {
|
bool CanPruneGraph(FuncOp func) {
|
||||||
return func.getName() != "main" ||
|
return func.getName() != "main" ||
|
||||||
func.getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
|
func->getAttrOfType<DictionaryAttr>("tf.entry_function") != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Visits an op's operand if it is an output of an Operation in the same
|
// Visits an op's operand if it is an output of an Operation in the same
|
||||||
|
|||||||
@ -124,7 +124,7 @@ void LayoutAssignmentPass::runOnFunction() {
|
|||||||
|
|
||||||
// Get runtime devices information from the closest parent module.
|
// Get runtime devices information from the closest parent module.
|
||||||
RuntimeDevices devices;
|
RuntimeDevices devices;
|
||||||
if (failed(::tensorflow::GetDevicesFromOp(func.getParentOfType<ModuleOp>(),
|
if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
|
||||||
&devices)))
|
&devices)))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
||||||
|
|||||||
@ -264,7 +264,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
|||||||
// Only if `allow_soft_placement` attribute is true should we mark ops
|
// Only if `allow_soft_placement` attribute is true should we mark ops
|
||||||
// for outside compilation.
|
// for outside compilation.
|
||||||
auto soft_placement_attr =
|
auto soft_placement_attr =
|
||||||
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||||
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
@ -281,7 +281,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() {
|
|||||||
// Only if `allow_soft_placement` attribute is true should we unmark ops
|
// Only if `allow_soft_placement` attribute is true should we unmark ops
|
||||||
// for outside compilation.
|
// for outside compilation.
|
||||||
auto soft_placement_attr =
|
auto soft_placement_attr =
|
||||||
cluster.getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
|
||||||
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
if (!(soft_placement_attr && soft_placement_attr.getValue())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -165,7 +165,7 @@ void CreateIslandsFromParallelExecute(
|
|||||||
unused_execute_controls.push_back(execute.control());
|
unused_execute_controls.push_back(execute.control());
|
||||||
|
|
||||||
if (!unused_execute_controls.empty()) {
|
if (!unused_execute_controls.empty()) {
|
||||||
auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
|
auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
|
||||||
tf_executor::FetchOp fetch = graph_op.GetFetch();
|
tf_executor::FetchOp fetch = graph_op.GetFetch();
|
||||||
auto fetches = llvm::to_vector<8>(fetch.getOperands());
|
auto fetches = llvm::to_vector<8>(fetch.getOperands());
|
||||||
fetches.append(unused_execute_controls.begin(),
|
fetches.append(unused_execute_controls.begin(),
|
||||||
|
|||||||
@ -138,7 +138,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
|
|||||||
ShapedType shaped_type =
|
ShapedType shaped_type =
|
||||||
variable_v2_op.getResult().getType().cast<ShapedType>();
|
variable_v2_op.getResult().getType().cast<ShapedType>();
|
||||||
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
|
TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
|
||||||
StringAttr device_attr = variable_v2_op.getAttrOfType<StringAttr>("device");
|
StringAttr device_attr =
|
||||||
|
variable_v2_op->getAttrOfType<StringAttr>("device");
|
||||||
if (!device_attr) device_attr = builder.getStringAttr("");
|
if (!device_attr) device_attr = builder.getStringAttr("");
|
||||||
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
|
StringRef variable_name = GetNodeNameFromClassAttr(variable_v2_op);
|
||||||
if (variable_name.empty()) {
|
if (variable_name.empty()) {
|
||||||
|
|||||||
@ -210,8 +210,8 @@ using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;
|
|||||||
bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
|
bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) {
|
||||||
if (first.getNumOperands() != second.getNumOperands()) return false;
|
if (first.getNumOperands() != second.getNumOperands()) return false;
|
||||||
|
|
||||||
Region& first_region = *first.getParentRegion();
|
Region& first_region = *first->getParentRegion();
|
||||||
Region& second_region = *second.getParentRegion();
|
Region& second_region = *second->getParentRegion();
|
||||||
|
|
||||||
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
|
for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) {
|
||||||
// Get the defining Op, skipping over casts.
|
// Get the defining Op, skipping over casts.
|
||||||
|
|||||||
@ -316,7 +316,7 @@ void ReplicateToIslandPass::runOnFunction() {
|
|||||||
});
|
});
|
||||||
|
|
||||||
for (tf_executor::IslandOp island_op : replicate_op_islands) {
|
for (tf_executor::IslandOp island_op : replicate_op_islands) {
|
||||||
auto graph_op = island_op.getParentOfType<tf_executor::GraphOp>();
|
auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
|
||||||
auto replicate_op =
|
auto replicate_op =
|
||||||
cast<tf_device::ReplicateOp>(island_op.GetBody().front());
|
cast<tf_device::ReplicateOp>(island_op.GetBody().front());
|
||||||
if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
|
if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op,
|
||||||
|
|||||||
@ -1106,7 +1106,7 @@ LogicalResult HandlePartitionedCallOpCallee(
|
|||||||
|
|
||||||
// Clone the callee before making changes.
|
// Clone the callee before making changes.
|
||||||
SmallString<64> name_base = callee.getName();
|
SmallString<64> name_base = callee.getName();
|
||||||
auto module = callee.getParentOfType<ModuleOp>();
|
auto module = callee->getParentOfType<ModuleOp>();
|
||||||
name_base += "_resource_lifted";
|
name_base += "_resource_lifted";
|
||||||
auto name = name_base;
|
auto name = name_base;
|
||||||
callee = callee.clone();
|
callee = callee.clone();
|
||||||
@ -1376,7 +1376,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) {
|
|||||||
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
llvm::SmallDenseMap<llvm::StringRef, PartitionedCallLiftingInfo>
|
||||||
lifted_partitioned_call_callees;
|
lifted_partitioned_call_callees;
|
||||||
if (failed(HoistForControlFlow(
|
if (failed(HoistForControlFlow(
|
||||||
&function.front(), cast<ModuleOp>(function.getParentOp()),
|
&function.front(), cast<ModuleOp>(function->getParentOp()),
|
||||||
/*vars_initialized=*/false, &lifted_partitioned_call_callees)))
|
/*vars_initialized=*/false, &lifted_partitioned_call_callees)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|||||||
@ -117,7 +117,7 @@ void EliminateUnusedResults(
|
|||||||
// multiple uses or unknown uses (for external functions). The cloned function
|
// multiple uses or unknown uses (for external functions). The cloned function
|
||||||
// will be marked as private.
|
// will be marked as private.
|
||||||
FuncOp CloneFunctionIfNeeded(FuncOp func) {
|
FuncOp CloneFunctionIfNeeded(FuncOp func) {
|
||||||
ModuleOp module = func.getParentOfType<ModuleOp>();
|
ModuleOp module = func->getParentOfType<ModuleOp>();
|
||||||
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
|
||||||
if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
|
if (func_uses.hasValue() && llvm::hasSingleElement(func_uses.getValue()))
|
||||||
return func;
|
return func;
|
||||||
|
|||||||
@ -247,7 +247,7 @@ bool CanInferTensorListElementType(Value tensorlist,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
|
if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
|
||||||
Operation* parent = yield.getParentOp();
|
Operation* parent = yield->getParentOp();
|
||||||
if (!CanInferTensorListElementType(
|
if (!CanInferTensorListElementType(
|
||||||
parent->getResult(use.getOperandNumber()), initial_element_shape,
|
parent->getResult(use.getOperandNumber()), initial_element_shape,
|
||||||
potential_element_type))
|
potential_element_type))
|
||||||
@ -619,7 +619,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context,
|
|||||||
ArrayRef<FuncOp> ShapeInference::GetCallers(FuncOp fn) {
|
ArrayRef<FuncOp> ShapeInference::GetCallers(FuncOp fn) {
|
||||||
auto pair = callers_of_func_.try_emplace(fn);
|
auto pair = callers_of_func_.try_emplace(fn);
|
||||||
if (pair.second) {
|
if (pair.second) {
|
||||||
ModuleOp module = fn.getParentOfType<ModuleOp>();
|
ModuleOp module = fn->getParentOfType<ModuleOp>();
|
||||||
auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
|
auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module);
|
||||||
if (uses) {
|
if (uses) {
|
||||||
pair.first->second.reserve(pair.first->second.size());
|
pair.first->second.reserve(pair.first->second.size());
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class TensorDeviceCopyConversionPass
|
|||||||
arg_device = attr;
|
arg_device = attr;
|
||||||
}
|
}
|
||||||
|
|
||||||
StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
|
StringAttr op_device = op->getAttrOfType<StringAttr>(kDeviceAttr);
|
||||||
if (!op_device) op_device = empty_string;
|
if (!op_device) op_device = empty_string;
|
||||||
// Skip the folding logic if the argument's device is different from the
|
// Skip the folding logic if the argument's device is different from the
|
||||||
// operation's device.
|
// operation's device.
|
||||||
|
|||||||
@ -116,7 +116,7 @@ void TPUColocateCompositeResourceOps::runOnFunction() {
|
|||||||
|
|
||||||
OpBuilder builder(&getContext());
|
OpBuilder builder(&getContext());
|
||||||
for (auto execute_launch : execute_launches) {
|
for (auto execute_launch : execute_launches) {
|
||||||
auto replicate = execute_launch.getParentOfType<tf_device::ReplicateOp>();
|
auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
|
||||||
if (!replicate) continue;
|
if (!replicate) continue;
|
||||||
|
|
||||||
ColocateCompositeResourceOpsInReplicate(replicate, &builder);
|
ColocateCompositeResourceOpsInReplicate(replicate, &builder);
|
||||||
|
|||||||
@ -109,7 +109,7 @@ bool IsSupportedInputOp(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Check all generator aliases (ops or function argument) are on CPU.
|
// Check all generator aliases (ops or function argument) are on CPU.
|
||||||
FuncOp func = iterator_op.getParentOfType<FuncOp>();
|
FuncOp func = iterator_op->getParentOfType<FuncOp>();
|
||||||
return llvm::all_of(aliases, [&](Value alias) {
|
return llvm::all_of(aliases, [&](Value alias) {
|
||||||
// Ignore non-generator aliases.
|
// Ignore non-generator aliases.
|
||||||
if (!is_generator(alias)) return true;
|
if (!is_generator(alias)) return true;
|
||||||
@ -230,7 +230,7 @@ void HandleCompileAndExecutes(
|
|||||||
|
|
||||||
bool metadata_updated = false;
|
bool metadata_updated = false;
|
||||||
auto maybe_replicate =
|
auto maybe_replicate =
|
||||||
execute_launches.front().getParentOfType<tf_device::ReplicateOp>();
|
execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();
|
||||||
|
|
||||||
for (auto execute_and_input_mapping :
|
for (auto execute_and_input_mapping :
|
||||||
llvm::zip(execute_launches, input_mappings)) {
|
llvm::zip(execute_launches, input_mappings)) {
|
||||||
@ -284,7 +284,7 @@ void TPUDynamicLayoutPass::runOnFunction(
|
|||||||
func.walk([&](TF::_TPUCompileMlirOp compile) {
|
func.walk([&](TF::_TPUCompileMlirOp compile) {
|
||||||
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
|
// Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
|
||||||
auto compile_launch =
|
auto compile_launch =
|
||||||
llvm::dyn_cast<tf_device::LaunchOp>(compile.getParentOp());
|
llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
|
||||||
if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
|
if (!compile_launch || !compile_launch.WrapsSingleOp()) return;
|
||||||
|
|
||||||
llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
|
llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
|
||||||
@ -295,7 +295,7 @@ void TPUDynamicLayoutPass::runOnFunction(
|
|||||||
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
|
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
|
||||||
if (!execute) return;
|
if (!execute) return;
|
||||||
auto execute_launch =
|
auto execute_launch =
|
||||||
llvm::dyn_cast<tf_device::LaunchOp>(execute.getParentOp());
|
llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
|
||||||
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
|
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
|
||||||
execute_launches.push_back(execute_launch);
|
execute_launches.push_back(execute_launch);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -180,7 +180,7 @@ void AnnotateFunctionArgumentsWithPaddings(
|
|||||||
|
|
||||||
LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
||||||
SymbolTable* symbol_table) {
|
SymbolTable* symbol_table) {
|
||||||
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||||
// LaunchFunc is not replicated, there will be no padding.
|
// LaunchFunc is not replicated, there will be no padding.
|
||||||
if (!replicate) return success();
|
if (!replicate) return success();
|
||||||
|
|
||||||
@ -188,7 +188,7 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
|||||||
if (!func) return success();
|
if (!func) return success();
|
||||||
|
|
||||||
auto replicated_input_indices_attr =
|
auto replicated_input_indices_attr =
|
||||||
replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
|
replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
|
||||||
if (!replicated_input_indices_attr) return success();
|
if (!replicated_input_indices_attr) return success();
|
||||||
|
|
||||||
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
|
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
|
||||||
|
|||||||
@ -131,7 +131,7 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
|||||||
const TF::SideEffectAnalysis& side_effect_analysis,
|
const TF::SideEffectAnalysis& side_effect_analysis,
|
||||||
tf_device::ClusterOp cluster) {
|
tf_device::ClusterOp cluster) {
|
||||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||||
cluster.getParentOfType<FuncOp>());
|
cluster->getParentOfType<FuncOp>());
|
||||||
Region* cluster_region = &cluster.body();
|
Region* cluster_region = &cluster.body();
|
||||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||||
|
|
||||||
@ -227,7 +227,7 @@ void FindOutsideCompiledOpsAtTailAndClusterResults(
|
|||||||
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
|
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
|
||||||
llvm::SmallVectorImpl<Value>* cluster_results) {
|
llvm::SmallVectorImpl<Value>* cluster_results) {
|
||||||
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
const auto& analysis = side_effect_analysis.GetAnalysisForFunc(
|
||||||
cluster.getParentOfType<FuncOp>());
|
cluster->getParentOfType<FuncOp>());
|
||||||
Region* cluster_region = &cluster.body();
|
Region* cluster_region = &cluster.body();
|
||||||
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
|
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
|
||||||
Operation* terminator = cluster.GetBody().getTerminator();
|
Operation* terminator = cluster.GetBody().getTerminator();
|
||||||
|
|||||||
@ -755,7 +755,7 @@ void MoveOutsideCompiledOps(
|
|||||||
// If there is no replication/data parallelism, it is assumed the device
|
// If there is no replication/data parallelism, it is assumed the device
|
||||||
// ordinal is always 0 (e.g. /device:TPU:0). In that case, a constant 0
|
// ordinal is always 0 (e.g. /device:TPU:0). In that case, a constant 0
|
||||||
// attribute can be used instead for _XlaSendFromHost/_XlaRecvAtHost ops.
|
// attribute can be used instead for _XlaSendFromHost/_XlaRecvAtHost ops.
|
||||||
if (tpu_cluster.getParentOfType<tf_device::ReplicateOp>()) {
|
if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
|
||||||
auto device_ordinal_op =
|
auto device_ordinal_op =
|
||||||
builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
|
builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
|
||||||
host_launch_op.getLoc(),
|
host_launch_op.getLoc(),
|
||||||
|
|||||||
@ -127,7 +127,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
|||||||
VariableAccessesForTPUExecute infos;
|
VariableAccessesForTPUExecute infos;
|
||||||
Attribute device_attr = execute_launch.deviceAttr();
|
Attribute device_attr = execute_launch.deviceAttr();
|
||||||
if (check_device && !device_attr) return infos;
|
if (check_device && !device_attr) return infos;
|
||||||
auto func = execute_launch.getParentOfType<mlir::FuncOp>();
|
auto func = execute_launch->getParentOfType<mlir::FuncOp>();
|
||||||
|
|
||||||
// Track the first read op found, which is used later to check if there are
|
// Track the first read op found, which is used later to check if there are
|
||||||
// assign ops between it and the TPUExecute op. We will exclude reads before
|
// assign ops between it and the TPUExecute op. We will exclude reads before
|
||||||
@ -137,7 +137,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
|||||||
Operation* first_read = nullptr;
|
Operation* first_read = nullptr;
|
||||||
Operation& execute = execute_launch.GetBody().front();
|
Operation& execute = execute_launch.GetBody().front();
|
||||||
auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||||
execute_launch.getParentOp());
|
execute_launch->getParentOp());
|
||||||
Operation* execute_parent =
|
Operation* execute_parent =
|
||||||
parallel_execute ? parallel_execute.getOperation() : execute_launch;
|
parallel_execute ? parallel_execute.getOperation() : execute_launch;
|
||||||
// Find inputs that are variable reads.
|
// Find inputs that are variable reads.
|
||||||
@ -148,7 +148,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
|||||||
operand.value().get().getDefiningOp());
|
operand.value().get().getDefiningOp());
|
||||||
if (!read_op) continue;
|
if (!read_op) continue;
|
||||||
if (check_same_region &&
|
if (check_same_region &&
|
||||||
read_op.getParentRegion() != execute_parent->getParentRegion())
|
read_op->getParentRegion() != execute_parent->getParentRegion())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto resource = read_op.resource();
|
auto resource = read_op.resource();
|
||||||
@ -240,7 +240,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
|||||||
auto execute_outputs =
|
auto execute_outputs =
|
||||||
parallel_execute
|
parallel_execute
|
||||||
? parallel_execute.GetRegionOutputs(
|
? parallel_execute.GetRegionOutputs(
|
||||||
execute_launch.getParentRegion()->getRegionNumber())
|
execute_launch->getParentRegion()->getRegionNumber())
|
||||||
: execute_launch.getResults();
|
: execute_launch.getResults();
|
||||||
for (auto execute_output : llvm::enumerate(execute_outputs)) {
|
for (auto execute_output : llvm::enumerate(execute_outputs)) {
|
||||||
// TODO(lyandy): Handle updates to resource writes by remapping to parent
|
// TODO(lyandy): Handle updates to resource writes by remapping to parent
|
||||||
@ -340,7 +340,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
|
|||||||
llvm::SmallVector<Type, 8> output_types;
|
llvm::SmallVector<Type, 8> output_types;
|
||||||
const int parallel_execute_num_results = parallel_execute_op->getNumResults();
|
const int parallel_execute_num_results = parallel_execute_op->getNumResults();
|
||||||
output_types.reserve(parallel_execute_num_results);
|
output_types.reserve(parallel_execute_num_results);
|
||||||
Region* execute_region = merged_execute_launch.getParentRegion();
|
Region* execute_region = merged_execute_launch->getParentRegion();
|
||||||
const int region_index = execute_region->getRegionNumber();
|
const int region_index = execute_region->getRegionNumber();
|
||||||
const int num_results_before_region =
|
const int num_results_before_region =
|
||||||
AppendTypes(&output_types, parallel_execute, 0, region_index);
|
AppendTypes(&output_types, parallel_execute, 0, region_index);
|
||||||
@ -547,7 +547,7 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
|
|||||||
merged_execute_launch.GetBody().getTerminator());
|
merged_execute_launch.GetBody().getTerminator());
|
||||||
|
|
||||||
if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||||
execute_launch.getParentOp()))
|
execute_launch->getParentOp()))
|
||||||
ReplaceParallelExecute(parallel_execute, execute_launch,
|
ReplaceParallelExecute(parallel_execute, execute_launch,
|
||||||
merged_execute_launch, infos, builder);
|
merged_execute_launch, infos, builder);
|
||||||
else
|
else
|
||||||
@ -591,11 +591,11 @@ void TPUMergeVariablesWithExecutePass::runOnFunction() {
|
|||||||
for (auto execute_launch : execute_launches) {
|
for (auto execute_launch : execute_launches) {
|
||||||
OpBuilder builder(&getContext());
|
OpBuilder builder(&getContext());
|
||||||
const bool parent_is_replicate =
|
const bool parent_is_replicate =
|
||||||
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp()) ||
|
llvm::isa<tf_device::ReplicateOp>(execute_launch->getParentOp()) ||
|
||||||
(llvm::isa<tf_device::ParallelExecuteOp>(
|
(llvm::isa<tf_device::ParallelExecuteOp>(
|
||||||
execute_launch.getParentOp()) &&
|
execute_launch->getParentOp()) &&
|
||||||
llvm::isa<tf_device::ReplicateOp>(
|
llvm::isa<tf_device::ReplicateOp>(
|
||||||
execute_launch.getParentOp()->getParentOp()));
|
execute_launch->getParentOp()->getParentOp()));
|
||||||
|
|
||||||
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
|
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
|
||||||
// to be on the same device as the TPUExecute op. Skip device checking in
|
// to be on the same device as the TPUExecute op. Skip device checking in
|
||||||
|
|||||||
@ -106,14 +106,14 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
|
|||||||
|
|
||||||
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
||||||
std::string* serialized_func_module) {
|
std::string* serialized_func_module) {
|
||||||
ModuleOp module = entry_func.getParentOfType<ModuleOp>();
|
ModuleOp module = entry_func->getParentOfType<ModuleOp>();
|
||||||
SymbolTable entry_module_table(module);
|
SymbolTable entry_module_table(module);
|
||||||
llvm::SmallVector<FuncOp, 4> referenced({entry_func});
|
llvm::SmallVector<FuncOp, 4> referenced({entry_func});
|
||||||
|
|
||||||
// Create a new module to hold func and all referenced functions.
|
// Create a new module to hold func and all referenced functions.
|
||||||
OwningModuleRef module_for_func =
|
OwningModuleRef module_for_func =
|
||||||
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
||||||
auto parent_module = entry_func.getParentOfType<ModuleOp>();
|
auto parent_module = entry_func->getParentOfType<ModuleOp>();
|
||||||
auto versions_attr = parent_module.getAttr(kVersionsAttr);
|
auto versions_attr = parent_module.getAttr(kVersionsAttr);
|
||||||
if (!versions_attr)
|
if (!versions_attr)
|
||||||
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
||||||
@ -165,7 +165,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
|
|||||||
tf_device::ClusterFuncOp op,
|
tf_device::ClusterFuncOp op,
|
||||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||||
auto step_marker_location =
|
auto step_marker_location =
|
||||||
op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
|
op->getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
|
||||||
if (!step_marker_location)
|
if (!step_marker_location)
|
||||||
return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
|
return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation(
|
|||||||
LogicalResult SetMetadataProtoPaddingMap(
|
LogicalResult SetMetadataProtoPaddingMap(
|
||||||
tf_device::ClusterFuncOp op,
|
tf_device::ClusterFuncOp op,
|
||||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||||
auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
|
auto padding_map = op->getAttrOfType<ArrayAttr>(kPaddingMapAttr);
|
||||||
if (!padding_map)
|
if (!padding_map)
|
||||||
return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
|
return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
|
||||||
|
|
||||||
@ -234,7 +234,7 @@ LogicalResult SetMetadataProtoArgs(
|
|||||||
tf_device::ClusterFuncOp op,
|
tf_device::ClusterFuncOp op,
|
||||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||||
auto input_shardings =
|
auto input_shardings =
|
||||||
op.getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
|
op->getAttrOfType<ArrayAttr>(tensorflow::kInputShardingAttr);
|
||||||
if (!input_shardings)
|
if (!input_shardings)
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));
|
CreateMissingAttributeMsg(tensorflow::kInputShardingAttr));
|
||||||
@ -289,7 +289,7 @@ LogicalResult SetMetadataProtoRetvals(
|
|||||||
tf_device::ClusterFuncOp op,
|
tf_device::ClusterFuncOp op,
|
||||||
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
tensorflow::tpu::TPUCompileMetadataProto* metadata) {
|
||||||
auto output_shardings =
|
auto output_shardings =
|
||||||
op.getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
|
op->getAttrOfType<ArrayAttr>(tensorflow::kOutputShardingAttr);
|
||||||
if (!output_shardings)
|
if (!output_shardings)
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));
|
CreateMissingAttributeMsg(tensorflow::kOutputShardingAttr));
|
||||||
@ -329,7 +329,7 @@ LogicalResult SetMetadataProtoFromClusterFuncOp(
|
|||||||
if (xla_device_assignment.hasValue())
|
if (xla_device_assignment.hasValue())
|
||||||
*metadata->mutable_device_assignment() =
|
*metadata->mutable_device_assignment() =
|
||||||
std::move(xla_device_assignment.getValue());
|
std::move(xla_device_assignment.getValue());
|
||||||
auto use_spmd_attr = op.getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
|
auto use_spmd_attr = op->getAttrOfType<BoolAttr>(kUseXlaSpmdAttr);
|
||||||
if (!use_spmd_attr)
|
if (!use_spmd_attr)
|
||||||
return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
|
return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr));
|
||||||
metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());
|
metadata->set_use_spmd_for_xla_partitioning(use_spmd_attr.getValue());
|
||||||
@ -400,7 +400,7 @@ Operation* BuildCompileOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
|
FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
|
||||||
FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||||
func_attr.getValue());
|
func_attr.getValue());
|
||||||
|
|
||||||
std::string txt_module;
|
std::string txt_module;
|
||||||
@ -637,16 +637,16 @@ LogicalResult Rewrite(
|
|||||||
OpBuilder* builder) {
|
OpBuilder* builder) {
|
||||||
// Skip non-tpu device cluster_func.
|
// Skip non-tpu device cluster_func.
|
||||||
auto replicate_attr =
|
auto replicate_attr =
|
||||||
cluster_func.getAttrOfType<StringAttr>("_tpu_replicate");
|
cluster_func->getAttrOfType<StringAttr>("_tpu_replicate");
|
||||||
if (!replicate_attr) return success();
|
if (!replicate_attr) return success();
|
||||||
|
|
||||||
// Collect `num_replicas` and `num_cores_per_replica` attributes.
|
// Collect `num_replicas` and `num_cores_per_replica` attributes.
|
||||||
int num_replicas = 1;
|
int num_replicas = 1;
|
||||||
tf_device::ReplicateOp replicate =
|
tf_device::ReplicateOp replicate =
|
||||||
cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
cluster_func->getParentOfType<tf_device::ReplicateOp>();
|
||||||
if (replicate) num_replicas = replicate.n();
|
if (replicate) num_replicas = replicate.n();
|
||||||
|
|
||||||
auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
|
auto num_cores_per_replica_attr = cluster_func->getAttrOfType<IntegerAttr>(
|
||||||
tensorflow::kNumCoresPerReplicaAttr);
|
tensorflow::kNumCoresPerReplicaAttr);
|
||||||
if (!num_cores_per_replica_attr)
|
if (!num_cores_per_replica_attr)
|
||||||
return cluster_func.emitOpError(
|
return cluster_func.emitOpError(
|
||||||
@ -655,12 +655,12 @@ LogicalResult Rewrite(
|
|||||||
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
||||||
|
|
||||||
auto topology_attr =
|
auto topology_attr =
|
||||||
cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
cluster_func->getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||||
if (!topology_attr)
|
if (!topology_attr)
|
||||||
return cluster_func.emitOpError(
|
return cluster_func.emitOpError(
|
||||||
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
|
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
|
||||||
|
|
||||||
auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
|
auto device_assignment_attr = cluster_func->getAttrOfType<mlir::ArrayAttr>(
|
||||||
tensorflow::kDeviceAssignmentAttr);
|
tensorflow::kDeviceAssignmentAttr);
|
||||||
if (!device_assignment_attr)
|
if (!device_assignment_attr)
|
||||||
return cluster_func.emitOpError(
|
return cluster_func.emitOpError(
|
||||||
@ -692,11 +692,11 @@ LogicalResult Rewrite(
|
|||||||
|
|
||||||
// Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
|
// Create the TPUCompileMlir and TPUCompileSucceededAssert outside of
|
||||||
// parallel_execute region if it exists.
|
// parallel_execute region if it exists.
|
||||||
if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func.getParentOp())) {
|
if (llvm::isa<tf_device::ParallelExecuteOp>(cluster_func->getParentOp())) {
|
||||||
// Currently, outside compilation and model parallelism are not supported
|
// Currently, outside compilation and model parallelism are not supported
|
||||||
// together.
|
// together.
|
||||||
assert(num_cores_per_replica == 1);
|
assert(num_cores_per_replica == 1);
|
||||||
builder->setInsertionPoint(cluster_func.getParentOp());
|
builder->setInsertionPoint(cluster_func->getParentOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* compile_op = BuildCompileOp(
|
Operation* compile_op = BuildCompileOp(
|
||||||
@ -711,7 +711,7 @@ LogicalResult Rewrite(
|
|||||||
// and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
|
// and _XlaRecvAtHostOp and _XlaSendFromHostOp are used, update to a more
|
||||||
// structured lowering.
|
// structured lowering.
|
||||||
if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
if (auto parallel_op = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
|
||||||
cluster_func.getParentOp())) {
|
cluster_func->getParentOp())) {
|
||||||
parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
|
parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) {
|
||||||
key_op.replaceAllUsesWith(compile_op->getResult(1));
|
key_op.replaceAllUsesWith(compile_op->getResult(1));
|
||||||
key_op.erase();
|
key_op.erase();
|
||||||
|
|||||||
@ -211,7 +211,7 @@ void IdentifyXlaShardingForComputationOutputs(
|
|||||||
void IdentifyXlaShardingForTPUComputation(
|
void IdentifyXlaShardingForTPUComputation(
|
||||||
Builder* builder, tf_device::ClusterFuncOp cluster_func) {
|
Builder* builder, tf_device::ClusterFuncOp cluster_func) {
|
||||||
// Look up function definition from module.
|
// Look up function definition from module.
|
||||||
FuncOp func = cluster_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
FuncOp func = cluster_func->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||||
cluster_func.func());
|
cluster_func.func());
|
||||||
|
|
||||||
// By default inputs/outputs have maximal sharding and are assigned to logical
|
// By default inputs/outputs have maximal sharding and are assigned to logical
|
||||||
|
|||||||
@ -483,7 +483,7 @@ bool HandleHostReplicatedInputs(int64_t index,
|
|||||||
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
|
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
|
||||||
unsigned arg_num) {
|
unsigned arg_num) {
|
||||||
auto maybe_replicate =
|
auto maybe_replicate =
|
||||||
llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func.getParentOp());
|
llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 8> transform_input_indices;
|
llvm::SmallVector<int64_t, 8> transform_input_indices;
|
||||||
for (auto input : llvm::enumerate(cluster_func.operands())) {
|
for (auto input : llvm::enumerate(cluster_func.operands())) {
|
||||||
|
|||||||
@ -151,7 +151,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
|||||||
|
|
||||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
|
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
|
||||||
auto mirrored_variable_indices_attr =
|
auto mirrored_variable_indices_attr =
|
||||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||||
if (!mirrored_variable_indices_attr) return mapping;
|
if (!mirrored_variable_indices_attr) return mapping;
|
||||||
|
|
||||||
// Finds the mapping from a replicate argument to an execute operand.
|
// Finds the mapping from a replicate argument to an execute operand.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user