Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.
PiperOrigin-RevId: 347083104 Change-Id: Iaa440d6afa09c5f3344317f83832c9e9b352ff69
This commit is contained in:
parent
bca4ea64e4
commit
d6eacc2cd3
@ -108,7 +108,7 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
|
||||
|
||||
FuncOp outlined_func =
|
||||
BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
|
||||
cluster_op.setAttr(builder->getIdentifier(kFuncAttr),
|
||||
cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
|
||||
builder->getSymbolRefAttr(outlined_func.getName()));
|
||||
|
||||
builder->setInsertionPoint(cluster_op);
|
||||
|
@ -208,7 +208,7 @@ void CreateFunctions(ModuleOp module_op,
|
||||
StringAttr::get(metadata.result_devices[i], context));
|
||||
}
|
||||
|
||||
func_op.setAttr(kHostAttr, StringAttr::get(host, context));
|
||||
func_op->setAttr(kHostAttr, StringAttr::get(host, context));
|
||||
func_op.setPublic();
|
||||
Block *block = func_op.addEntryBlock();
|
||||
|
||||
|
@ -68,8 +68,8 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
|
||||
return signalPassFailure();
|
||||
}
|
||||
ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
|
||||
outlined_module.setAttrs(getOperation().getAttrs());
|
||||
outlined_module.setAttr(SymbolTable::getSymbolAttrName(),
|
||||
outlined_module->setAttrs(getOperation().getAttrs());
|
||||
outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
|
||||
StringAttr::get(kNestedModule, ctx));
|
||||
symbol_table.insert(outlined_module);
|
||||
SymbolTable outlined_symbol_table(outlined_module);
|
||||
|
@ -121,7 +121,7 @@ LogicalResult UpdateRegionReplicateVariantOps(
|
||||
// Map aliased devices to explicit devices based on replica.
|
||||
if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
|
||||
if (auto device_by_replica = devices.getValue().get(launch.device()))
|
||||
launch.setAttr(
|
||||
launch->setAttr(
|
||||
kDeviceAttr,
|
||||
device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());
|
||||
|
||||
|
@ -192,7 +192,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
||||
if (auto device = result->DeviceForResource(output)) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< " Setting device = " << *device << "\n");
|
||||
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
||||
identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
|
||||
}
|
||||
}
|
||||
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
|
||||
|
@ -1180,7 +1180,7 @@ void UpdatePartitionedCallOpWithNewCallee(
|
||||
auto new_call = builder.create<CallOpType>(
|
||||
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
|
||||
new_operands, call_op.getAttrs());
|
||||
new_call.setAttr(
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
|
||||
AddLoadsStoresOutsideControlFlowOp(
|
||||
new_call, lifting_info.arg_data_type_and_updated_output_index);
|
||||
|
@ -94,7 +94,7 @@ void RewriteTPUEmbeddingOps::runOnFunction() {
|
||||
|
||||
auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
|
||||
send_op, dedup_op, &builder);
|
||||
new_send_op.setAttr(new_send_op.getOperandSegmentSizeAttr(),
|
||||
new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
|
||||
operand_size_attr);
|
||||
}
|
||||
}
|
||||
|
@ -307,7 +307,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_call.setAttr(
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
for (int64_t i = 0; i < call.getNumResults(); ++i) {
|
||||
|
@ -752,7 +752,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_call.setAttr(
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
for (const auto& entry : info.ret_forward_input) {
|
||||
|
@ -457,7 +457,7 @@ LogicalResult HandlePartitionedCallOp(
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_call.setAttr(
|
||||
new_call->setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
for (const auto& entry : info.buffer_ret_to_size_ret) {
|
||||
|
@ -419,11 +419,11 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
|
||||
replicated_inputs, packed_inputs, cluster.getResultTypes());
|
||||
if (has_replicated_input_index)
|
||||
replicate_op.setAttr(kReplicatedInputIndicesAttr,
|
||||
replicate_op->setAttr(kReplicatedInputIndicesAttr,
|
||||
builder.getI64ArrayAttr(replicated_input_indices));
|
||||
|
||||
if (!mirrored_variable_indices.empty())
|
||||
replicate_op.setAttr(kMirroredVariableIndicesAttr,
|
||||
replicate_op->setAttr(kMirroredVariableIndicesAttr,
|
||||
builder.getI64ArrayAttr(mirrored_variable_indices));
|
||||
|
||||
// Replace replicated cluster results with replicate op results.
|
||||
@ -550,7 +550,7 @@ LogicalResult FormClustersInBlock(
|
||||
return failure();
|
||||
|
||||
// Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
|
||||
cluster.setAttrs(cluster_metadata->second);
|
||||
cluster->setAttrs(cluster_metadata->second);
|
||||
// Exclude `num_replicas` as cluster should be replicated if necessary.
|
||||
cluster.removeAttr(kNumReplicasAttr);
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ class TPUCompileOpReplicationPass
|
||||
builder.create<TF::TPUCompileSucceededAssertOp>(
|
||||
new_compile_op->getLoc(),
|
||||
new_compile_op->getResult(kStatusResultIndex));
|
||||
new_assert_op.setAttr(kDeviceAttr,
|
||||
new_assert_op->setAttr(kDeviceAttr,
|
||||
new_compile_op->getAttr(kDeviceAttr));
|
||||
}
|
||||
// Updates the operand to use the result of the newly created
|
||||
|
@ -198,7 +198,7 @@ void PropagateDevicesInGraph(
|
||||
if (auto sink =
|
||||
llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
|
||||
auto source = sink.GetSource();
|
||||
source.setAttr(kDeviceAttr, new_device_attr);
|
||||
source->setAttr(kDeviceAttr, new_device_attr);
|
||||
PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
|
||||
value_to_device);
|
||||
updated_next_iteration = true;
|
||||
|
@ -172,7 +172,7 @@ void HandleInput(Value input, const int64_t execute_arg_index,
|
||||
builder.setInsertionPoint(execute_launch);
|
||||
auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
|
||||
get_layout, input, &builder);
|
||||
copy_with_layout.setAttr(kDeviceAttr, execute_launch.deviceAttr());
|
||||
copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
|
||||
execute.setOperand(execute_arg_index, copy_with_layout);
|
||||
}
|
||||
|
||||
@ -206,7 +206,7 @@ bool HandleReplicatedInputs(
|
||||
.getValue()
|
||||
.get(execute_launch.getDevice())
|
||||
.cast<ArrayAttr>();
|
||||
copy_with_layout.setAttr(kDeviceAttr,
|
||||
copy_with_layout->setAttr(kDeviceAttr,
|
||||
device_list.getValue()[entry.index()]);
|
||||
|
||||
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
|
||||
@ -274,7 +274,7 @@ void HandleCompileAndExecutes(
|
||||
}
|
||||
|
||||
if (metadata_updated)
|
||||
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
||||
compile->setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
||||
compile.getContext()));
|
||||
}
|
||||
|
||||
|
@ -94,7 +94,7 @@ LogicalResult GetRemappedPaddings(
|
||||
.str();
|
||||
};
|
||||
|
||||
Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr);
|
||||
Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
|
||||
if (!padding_map_attr) return success();
|
||||
|
||||
auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
|
||||
|
@ -114,7 +114,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
||||
OwningModuleRef module_for_func =
|
||||
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
|
||||
auto parent_module = entry_func->getParentOfType<ModuleOp>();
|
||||
auto versions_attr = parent_module.getAttr(kVersionsAttr);
|
||||
auto versions_attr = parent_module->getAttr(kVersionsAttr);
|
||||
if (!versions_attr)
|
||||
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
|
||||
|
||||
@ -457,7 +457,7 @@ void AssignDevicesToReplicate(
|
||||
tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
|
||||
}
|
||||
|
||||
replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
|
||||
replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
|
||||
}
|
||||
|
||||
// Creates a `tf.TPUExecute` op that executes TPU program.
|
||||
|
@ -130,7 +130,7 @@ void IdentifyXlaShardingForComputationInputs(
|
||||
}
|
||||
}
|
||||
|
||||
cluster_func_op.setAttr(tensorflow::kInputShardingAttr,
|
||||
cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
|
||||
builder->getStrArrayAttr(sharding_for_args));
|
||||
}
|
||||
|
||||
@ -202,7 +202,7 @@ void IdentifyXlaShardingForComputationOutputs(
|
||||
}
|
||||
}
|
||||
|
||||
cluster_func.setAttr(tensorflow::kOutputShardingAttr,
|
||||
cluster_func->setAttr(tensorflow::kOutputShardingAttr,
|
||||
builder->getStrArrayAttr(sharding_for_rets));
|
||||
}
|
||||
|
||||
|
@ -200,7 +200,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
|
||||
});
|
||||
// TODO(b/157276506): change type of strides to DenseElementsAttr
|
||||
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
|
||||
conv2d.setAttr("strides", strides);
|
||||
conv2d->setAttr("strides", strides);
|
||||
}
|
||||
|
||||
// Transforms input shape for the first convolution.
|
||||
|
@ -128,7 +128,7 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
|
||||
auto outside_compilation_attr =
|
||||
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
|
||||
if (outside_compilation_attr)
|
||||
enqueue_mode.setAttr(kXlaOutsideCompilationAttr,
|
||||
enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
|
||||
outside_compilation_attr);
|
||||
|
||||
mode_enqueue_operand.set(enqueue_mode);
|
||||
|
Loading…
Reference in New Issue
Block a user