Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.
PiperOrigin-RevId: 347083104 Change-Id: Iaa440d6afa09c5f3344317f83832c9e9b352ff69
This commit is contained in:
parent
bca4ea64e4
commit
d6eacc2cd3
@ -108,8 +108,8 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
|
|||||||
|
|
||||||
FuncOp outlined_func =
|
FuncOp outlined_func =
|
||||||
BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
|
BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
|
||||||
cluster_op.setAttr(builder->getIdentifier(kFuncAttr),
|
cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
|
||||||
builder->getSymbolRefAttr(outlined_func.getName()));
|
builder->getSymbolRefAttr(outlined_func.getName()));
|
||||||
|
|
||||||
builder->setInsertionPoint(cluster_op);
|
builder->setInsertionPoint(cluster_op);
|
||||||
auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
|
auto cluster_func_op = builder->create<tf_device::ClusterFuncOp>(
|
||||||
|
@ -208,7 +208,7 @@ void CreateFunctions(ModuleOp module_op,
|
|||||||
StringAttr::get(metadata.result_devices[i], context));
|
StringAttr::get(metadata.result_devices[i], context));
|
||||||
}
|
}
|
||||||
|
|
||||||
func_op.setAttr(kHostAttr, StringAttr::get(host, context));
|
func_op->setAttr(kHostAttr, StringAttr::get(host, context));
|
||||||
func_op.setPublic();
|
func_op.setPublic();
|
||||||
Block *block = func_op.addEntryBlock();
|
Block *block = func_op.addEntryBlock();
|
||||||
|
|
||||||
|
@ -68,9 +68,9 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
|||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
|
ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
|
||||||
outlined_module.setAttrs(getOperation().getAttrs());
|
outlined_module->setAttrs(getOperation().getAttrs());
|
||||||
outlined_module.setAttr(SymbolTable::getSymbolAttrName(),
|
outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
|
||||||
StringAttr::get(kNestedModule, ctx));
|
StringAttr::get(kNestedModule, ctx));
|
||||||
symbol_table.insert(outlined_module);
|
symbol_table.insert(outlined_module);
|
||||||
SymbolTable outlined_symbol_table(outlined_module);
|
SymbolTable outlined_symbol_table(outlined_module);
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ LogicalResult UpdateRegionReplicateVariantOps(
|
|||||||
// Map aliased devices to explicit devices based on replica.
|
// Map aliased devices to explicit devices based on replica.
|
||||||
if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
|
if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
|
||||||
if (auto device_by_replica = devices.getValue().get(launch.device()))
|
if (auto device_by_replica = devices.getValue().get(launch.device()))
|
||||||
launch.setAttr(
|
launch->setAttr(
|
||||||
kDeviceAttr,
|
kDeviceAttr,
|
||||||
device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());
|
device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());
|
||||||
|
|
||||||
|
@ -192,7 +192,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
|||||||
if (auto device = result->DeviceForResource(output)) {
|
if (auto device = result->DeviceForResource(output)) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
<< " Setting device = " << *device << "\n");
|
<< " Setting device = " << *device << "\n");
|
||||||
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
|
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
|
||||||
|
@ -1180,7 +1180,7 @@ void UpdatePartitionedCallOpWithNewCallee(
|
|||||||
auto new_call = builder.create<CallOpType>(
|
auto new_call = builder.create<CallOpType>(
|
||||||
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
|
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
|
||||||
new_operands, call_op.getAttrs());
|
new_operands, call_op.getAttrs());
|
||||||
new_call.setAttr(
|
new_call->setAttr(
|
||||||
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
|
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
|
||||||
AddLoadsStoresOutsideControlFlowOp(
|
AddLoadsStoresOutsideControlFlowOp(
|
||||||
new_call, lifting_info.arg_data_type_and_updated_output_index);
|
new_call, lifting_info.arg_data_type_and_updated_output_index);
|
||||||
|
@ -94,8 +94,8 @@ void RewriteTPUEmbeddingOps::runOnFunction() {
|
|||||||
|
|
||||||
auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
|
auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
|
||||||
send_op, dedup_op, &builder);
|
send_op, dedup_op, &builder);
|
||||||
new_send_op.setAttr(new_send_op.getOperandSegmentSizeAttr(),
|
new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
|
||||||
operand_size_attr);
|
operand_size_attr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -307,7 +307,7 @@ LogicalResult HandlePartitionedCallOp(
|
|||||||
auto new_call = builder.create<CallOp>(
|
auto new_call = builder.create<CallOp>(
|
||||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||||
new_operands, call.getAttrs());
|
new_operands, call.getAttrs());
|
||||||
new_call.setAttr(
|
new_call->setAttr(
|
||||||
"f", builder.getSymbolRefAttr(
|
"f", builder.getSymbolRefAttr(
|
||||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||||
for (int64_t i = 0; i < call.getNumResults(); ++i) {
|
for (int64_t i = 0; i < call.getNumResults(); ++i) {
|
||||||
|
@ -752,7 +752,7 @@ LogicalResult HandlePartitionedCallOp(
|
|||||||
auto new_call = builder.create<CallOp>(
|
auto new_call = builder.create<CallOp>(
|
||||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||||
new_operands, call.getAttrs());
|
new_operands, call.getAttrs());
|
||||||
new_call.setAttr(
|
new_call->setAttr(
|
||||||
"f", builder.getSymbolRefAttr(
|
"f", builder.getSymbolRefAttr(
|
||||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||||
for (const auto& entry : info.ret_forward_input) {
|
for (const auto& entry : info.ret_forward_input) {
|
||||||
|
@ -457,7 +457,7 @@ LogicalResult HandlePartitionedCallOp(
|
|||||||
auto new_call = builder.create<CallOp>(
|
auto new_call = builder.create<CallOp>(
|
||||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||||
new_operands, call.getAttrs());
|
new_operands, call.getAttrs());
|
||||||
new_call.setAttr(
|
new_call->setAttr(
|
||||||
"f", builder.getSymbolRefAttr(
|
"f", builder.getSymbolRefAttr(
|
||||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||||
for (const auto& entry : info.buffer_ret_to_size_ret) {
|
for (const auto& entry : info.buffer_ret_to_size_ret) {
|
||||||
|
@ -419,12 +419,12 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
|||||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
|
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
|
||||||
replicated_inputs, packed_inputs, cluster.getResultTypes());
|
replicated_inputs, packed_inputs, cluster.getResultTypes());
|
||||||
if (has_replicated_input_index)
|
if (has_replicated_input_index)
|
||||||
replicate_op.setAttr(kReplicatedInputIndicesAttr,
|
replicate_op->setAttr(kReplicatedInputIndicesAttr,
|
||||||
builder.getI64ArrayAttr(replicated_input_indices));
|
builder.getI64ArrayAttr(replicated_input_indices));
|
||||||
|
|
||||||
if (!mirrored_variable_indices.empty())
|
if (!mirrored_variable_indices.empty())
|
||||||
replicate_op.setAttr(kMirroredVariableIndicesAttr,
|
replicate_op->setAttr(kMirroredVariableIndicesAttr,
|
||||||
builder.getI64ArrayAttr(mirrored_variable_indices));
|
builder.getI64ArrayAttr(mirrored_variable_indices));
|
||||||
|
|
||||||
// Replace replicated cluster results with replicate op results.
|
// Replace replicated cluster results with replicate op results.
|
||||||
for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
|
for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
|
||||||
@ -550,7 +550,7 @@ LogicalResult FormClustersInBlock(
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
|
// Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
|
||||||
cluster.setAttrs(cluster_metadata->second);
|
cluster->setAttrs(cluster_metadata->second);
|
||||||
// Exclude `num_replicas` as cluster should be replicated if necessary.
|
// Exclude `num_replicas` as cluster should be replicated if necessary.
|
||||||
cluster.removeAttr(kNumReplicasAttr);
|
cluster.removeAttr(kNumReplicasAttr);
|
||||||
}
|
}
|
||||||
|
@ -76,8 +76,8 @@ class TPUCompileOpReplicationPass
|
|||||||
builder.create<TF::TPUCompileSucceededAssertOp>(
|
builder.create<TF::TPUCompileSucceededAssertOp>(
|
||||||
new_compile_op->getLoc(),
|
new_compile_op->getLoc(),
|
||||||
new_compile_op->getResult(kStatusResultIndex));
|
new_compile_op->getResult(kStatusResultIndex));
|
||||||
new_assert_op.setAttr(kDeviceAttr,
|
new_assert_op->setAttr(kDeviceAttr,
|
||||||
new_compile_op->getAttr(kDeviceAttr));
|
new_compile_op->getAttr(kDeviceAttr));
|
||||||
}
|
}
|
||||||
// Updates the operand to use the result of the newly created
|
// Updates the operand to use the result of the newly created
|
||||||
// tf._TPUCompileMlir op.
|
// tf._TPUCompileMlir op.
|
||||||
|
@ -198,7 +198,7 @@ void PropagateDevicesInGraph(
|
|||||||
if (auto sink =
|
if (auto sink =
|
||||||
llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
|
llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
|
||||||
auto source = sink.GetSource();
|
auto source = sink.GetSource();
|
||||||
source.setAttr(kDeviceAttr, new_device_attr);
|
source->setAttr(kDeviceAttr, new_device_attr);
|
||||||
PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
|
PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
|
||||||
value_to_device);
|
value_to_device);
|
||||||
updated_next_iteration = true;
|
updated_next_iteration = true;
|
||||||
|
@ -172,7 +172,7 @@ void HandleInput(Value input, const int64_t execute_arg_index,
|
|||||||
builder.setInsertionPoint(execute_launch);
|
builder.setInsertionPoint(execute_launch);
|
||||||
auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
|
auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
|
||||||
get_layout, input, &builder);
|
get_layout, input, &builder);
|
||||||
copy_with_layout.setAttr(kDeviceAttr, execute_launch.deviceAttr());
|
copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
|
||||||
execute.setOperand(execute_arg_index, copy_with_layout);
|
execute.setOperand(execute_arg_index, copy_with_layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,8 +206,8 @@ bool HandleReplicatedInputs(
|
|||||||
.getValue()
|
.getValue()
|
||||||
.get(execute_launch.getDevice())
|
.get(execute_launch.getDevice())
|
||||||
.cast<ArrayAttr>();
|
.cast<ArrayAttr>();
|
||||||
copy_with_layout.setAttr(kDeviceAttr,
|
copy_with_layout->setAttr(kDeviceAttr,
|
||||||
device_list.getValue()[entry.index()]);
|
device_list.getValue()[entry.index()]);
|
||||||
|
|
||||||
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
|
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
|
||||||
copy_with_layout);
|
copy_with_layout);
|
||||||
@ -274,8 +274,8 @@ void HandleCompileAndExecutes(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (metadata_updated)
|
if (metadata_updated)
|
||||||
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
compile->setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
||||||
compile.getContext()));
|
compile.getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
void TPUDynamicLayoutPass::runOnFunction(
|
void TPUDynamicLayoutPass::runOnFunction(
|
||||||
|
@ -94,7 +94,7 @@ LogicalResult GetRemappedPaddings(
|
|||||||
.str();
|
.str();
|
||||||
};
|
};
|
||||||
|
|
||||||
Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr);
|
Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
|
||||||
if (!padding_map_attr) return success();
|
if (!padding_map_attr) return success();
|
||||||
|
|
||||||
auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
|
auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
|
||||||
|
@ -114,7 +114,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
|||||||
OwningModuleRef module_for_func =
|
OwningModuleRef module_for_func =
|
||||||
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
||||||
auto parent_module = entry_func->getParentOfType<ModuleOp>();
|
auto parent_module = entry_func->getParentOfType<ModuleOp>();
|
||||||
auto versions_attr = parent_module.getAttr(kVersionsAttr);
|
auto versions_attr = parent_module->getAttr(kVersionsAttr);
|
||||||
if (!versions_attr)
|
if (!versions_attr)
|
||||||
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
||||||
|
|
||||||
@ -457,7 +457,7 @@ void AssignDevicesToReplicate(
|
|||||||
tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
|
tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
|
||||||
}
|
}
|
||||||
|
|
||||||
replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
|
replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a `tf.TPUExecute` op that executes TPU program.
|
// Creates a `tf.TPUExecute` op that executes TPU program.
|
||||||
|
@ -130,8 +130,8 @@ void IdentifyXlaShardingForComputationInputs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster_func_op.setAttr(tensorflow::kInputShardingAttr,
|
cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
|
||||||
builder->getStrArrayAttr(sharding_for_args));
|
builder->getStrArrayAttr(sharding_for_args));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds XlaSharding op connected to a result value. XlaSharding op may be
|
// Finds XlaSharding op connected to a result value. XlaSharding op may be
|
||||||
@ -202,8 +202,8 @@ void IdentifyXlaShardingForComputationOutputs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster_func.setAttr(tensorflow::kOutputShardingAttr,
|
cluster_func->setAttr(tensorflow::kOutputShardingAttr,
|
||||||
builder->getStrArrayAttr(sharding_for_rets));
|
builder->getStrArrayAttr(sharding_for_rets));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extracts input/output sharding configuration of `cluster_func` by parsing
|
// Extracts input/output sharding configuration of `cluster_func` by parsing
|
||||||
|
@ -200,7 +200,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
|
|||||||
});
|
});
|
||||||
// TODO(b/157276506): change type of strides to DenseElementsAttr
|
// TODO(b/157276506): change type of strides to DenseElementsAttr
|
||||||
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
|
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
|
||||||
conv2d.setAttr("strides", strides);
|
conv2d->setAttr("strides", strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transforms input shape for the first convolution.
|
// Transforms input shape for the first convolution.
|
||||||
|
@ -128,8 +128,8 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
|
|||||||
auto outside_compilation_attr =
|
auto outside_compilation_attr =
|
||||||
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
|
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
|
||||||
if (outside_compilation_attr)
|
if (outside_compilation_attr)
|
||||||
enqueue_mode.setAttr(kXlaOutsideCompilationAttr,
|
enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
|
||||||
outside_compilation_attr);
|
outside_compilation_attr);
|
||||||
|
|
||||||
mode_enqueue_operand.set(enqueue_mode);
|
mode_enqueue_operand.set(enqueue_mode);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user