Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.

PiperOrigin-RevId: 347083104
Change-Id: Iaa440d6afa09c5f3344317f83832c9e9b352ff69
This commit is contained in:
Christian Sigg 2020-12-11 15:06:17 -08:00 committed by TensorFlower Gardener
parent bca4ea64e4
commit d6eacc2cd3
19 changed files with 37 additions and 37 deletions

View File

@ -108,7 +108,7 @@ void OutlineCluster(tf_device::ClusterOp cluster_op, SymbolTable* symbol_table,
FuncOp outlined_func = FuncOp outlined_func =
BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder); BuildFunction(live_ins.getArrayRef(), cluster_op, symbol_table, builder);
cluster_op.setAttr(builder->getIdentifier(kFuncAttr), cluster_op->setAttr(builder->getIdentifier(kFuncAttr),
builder->getSymbolRefAttr(outlined_func.getName())); builder->getSymbolRefAttr(outlined_func.getName()));
builder->setInsertionPoint(cluster_op); builder->setInsertionPoint(cluster_op);

View File

@ -208,7 +208,7 @@ void CreateFunctions(ModuleOp module_op,
StringAttr::get(metadata.result_devices[i], context)); StringAttr::get(metadata.result_devices[i], context));
} }
func_op.setAttr(kHostAttr, StringAttr::get(host, context)); func_op->setAttr(kHostAttr, StringAttr::get(host, context));
func_op.setPublic(); func_op.setPublic();
Block *block = func_op.addEntryBlock(); Block *block = func_op.addEntryBlock();

View File

@ -68,8 +68,8 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() {
return signalPassFailure(); return signalPassFailure();
} }
ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc()); ModuleOp outlined_module = ModuleOp::create(getOperation().getLoc());
outlined_module.setAttrs(getOperation().getAttrs()); outlined_module->setAttrs(getOperation().getAttrs());
outlined_module.setAttr(SymbolTable::getSymbolAttrName(), outlined_module->setAttr(SymbolTable::getSymbolAttrName(),
StringAttr::get(kNestedModule, ctx)); StringAttr::get(kNestedModule, ctx));
symbol_table.insert(outlined_module); symbol_table.insert(outlined_module);
SymbolTable outlined_symbol_table(outlined_module); SymbolTable outlined_symbol_table(outlined_module);

View File

@ -121,7 +121,7 @@ LogicalResult UpdateRegionReplicateVariantOps(
// Map aliased devices to explicit devices based on replica. // Map aliased devices to explicit devices based on replica.
if (auto launch = dyn_cast<tf_device::LaunchOp>(op)) if (auto launch = dyn_cast<tf_device::LaunchOp>(op))
if (auto device_by_replica = devices.getValue().get(launch.device())) if (auto device_by_replica = devices.getValue().get(launch.device()))
launch.setAttr( launch->setAttr(
kDeviceAttr, kDeviceAttr,
device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>()); device_by_replica.cast<ArrayAttr>()[replica_id].cast<StringAttr>());

View File

@ -192,7 +192,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
if (auto device = result->DeviceForResource(output)) { if (auto device = result->DeviceForResource(output)) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
<< " Setting device = " << *device << "\n"); << " Setting device = " << *device << "\n");
identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); identity->setAttr(kDeviceAttr, builder.getStringAttr(*device));
} }
} }
} else if (auto while_region = dyn_cast<WhileRegionOp>(op)) { } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {

View File

@ -1180,7 +1180,7 @@ void UpdatePartitionedCallOpWithNewCallee(
auto new_call = builder.create<CallOpType>( auto new_call = builder.create<CallOpType>(
call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(), call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(),
new_operands, call_op.getAttrs()); new_operands, call_op.getAttrs());
new_call.setAttr( new_call->setAttr(
"f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName())); "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName()));
AddLoadsStoresOutsideControlFlowOp( AddLoadsStoresOutsideControlFlowOp(
new_call, lifting_info.arg_data_type_and_updated_output_index); new_call, lifting_info.arg_data_type_and_updated_output_index);

View File

@ -94,7 +94,7 @@ void RewriteTPUEmbeddingOps::runOnFunction() {
auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>( auto new_send_op = AddOperandAndRewriteAs<_SendTPUEmbeddingGradientsOp>(
send_op, dedup_op, &builder); send_op, dedup_op, &builder);
new_send_op.setAttr(new_send_op.getOperandSegmentSizeAttr(), new_send_op->setAttr(new_send_op.getOperandSegmentSizeAttr(),
operand_size_attr); operand_size_attr);
} }
} }

View File

@ -307,7 +307,7 @@ LogicalResult HandlePartitionedCallOp(
auto new_call = builder.create<CallOp>( auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(), call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs()); new_operands, call.getAttrs());
new_call.setAttr( new_call->setAttr(
"f", builder.getSymbolRefAttr( "f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName())); const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (int64_t i = 0; i < call.getNumResults(); ++i) { for (int64_t i = 0; i < call.getNumResults(); ++i) {

View File

@ -752,7 +752,7 @@ LogicalResult HandlePartitionedCallOp(
auto new_call = builder.create<CallOp>( auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(), call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs()); new_operands, call.getAttrs());
new_call.setAttr( new_call->setAttr(
"f", builder.getSymbolRefAttr( "f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName())); const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (const auto& entry : info.ret_forward_input) { for (const auto& entry : info.ret_forward_input) {

View File

@ -457,7 +457,7 @@ LogicalResult HandlePartitionedCallOp(
auto new_call = builder.create<CallOp>( auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(), call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs()); new_operands, call.getAttrs());
new_call.setAttr( new_call->setAttr(
"f", builder.getSymbolRefAttr( "f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName())); const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (const auto& entry : info.buffer_ret_to_size_ret) { for (const auto& entry : info.buffer_ret_to_size_ret) {

View File

@ -419,11 +419,11 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(), llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
replicated_inputs, packed_inputs, cluster.getResultTypes()); replicated_inputs, packed_inputs, cluster.getResultTypes());
if (has_replicated_input_index) if (has_replicated_input_index)
replicate_op.setAttr(kReplicatedInputIndicesAttr, replicate_op->setAttr(kReplicatedInputIndicesAttr,
builder.getI64ArrayAttr(replicated_input_indices)); builder.getI64ArrayAttr(replicated_input_indices));
if (!mirrored_variable_indices.empty()) if (!mirrored_variable_indices.empty())
replicate_op.setAttr(kMirroredVariableIndicesAttr, replicate_op->setAttr(kMirroredVariableIndicesAttr,
builder.getI64ArrayAttr(mirrored_variable_indices)); builder.getI64ArrayAttr(mirrored_variable_indices));
// Replace replicated cluster results with replicate op results. // Replace replicated cluster results with replicate op results.
@ -550,7 +550,7 @@ LogicalResult FormClustersInBlock(
return failure(); return failure();
// Copy TPUReplicateMetadata attributes to `tf_device.cluster`. // Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
cluster.setAttrs(cluster_metadata->second); cluster->setAttrs(cluster_metadata->second);
// Exclude `num_replicas` as cluster should be replicated if necessary. // Exclude `num_replicas` as cluster should be replicated if necessary.
cluster.removeAttr(kNumReplicasAttr); cluster.removeAttr(kNumReplicasAttr);
} }

View File

@ -76,7 +76,7 @@ class TPUCompileOpReplicationPass
builder.create<TF::TPUCompileSucceededAssertOp>( builder.create<TF::TPUCompileSucceededAssertOp>(
new_compile_op->getLoc(), new_compile_op->getLoc(),
new_compile_op->getResult(kStatusResultIndex)); new_compile_op->getResult(kStatusResultIndex));
new_assert_op.setAttr(kDeviceAttr, new_assert_op->setAttr(kDeviceAttr,
new_compile_op->getAttr(kDeviceAttr)); new_compile_op->getAttr(kDeviceAttr));
} }
// Updates the operand to use the result of the newly created // Updates the operand to use the result of the newly created

View File

@ -198,7 +198,7 @@ void PropagateDevicesInGraph(
if (auto sink = if (auto sink =
llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) { llvm::dyn_cast<tf_executor::NextIterationSinkOp>(op_to_update)) {
auto source = sink.GetSource(); auto source = sink.GetSource();
source.setAttr(kDeviceAttr, new_device_attr); source->setAttr(kDeviceAttr, new_device_attr);
PopulateDeviceForOpResults(*source, new_device_attr.getValue(), PopulateDeviceForOpResults(*source, new_device_attr.getValue(),
value_to_device); value_to_device);
updated_next_iteration = true; updated_next_iteration = true;

View File

@ -172,7 +172,7 @@ void HandleInput(Value input, const int64_t execute_arg_index,
builder.setInsertionPoint(execute_launch); builder.setInsertionPoint(execute_launch);
auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch, auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
get_layout, input, &builder); get_layout, input, &builder);
copy_with_layout.setAttr(kDeviceAttr, execute_launch.deviceAttr()); copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
execute.setOperand(execute_arg_index, copy_with_layout); execute.setOperand(execute_arg_index, copy_with_layout);
} }
@ -206,7 +206,7 @@ bool HandleReplicatedInputs(
.getValue() .getValue()
.get(execute_launch.getDevice()) .get(execute_launch.getDevice())
.cast<ArrayAttr>(); .cast<ArrayAttr>();
copy_with_layout.setAttr(kDeviceAttr, copy_with_layout->setAttr(kDeviceAttr,
device_list.getValue()[entry.index()]); device_list.getValue()[entry.index()]);
replicate.setOperand(num_replicas * replicate_arg_index + entry.index(), replicate.setOperand(num_replicas * replicate_arg_index + entry.index(),
@ -274,7 +274,7 @@ void HandleCompileAndExecutes(
} }
if (metadata_updated) if (metadata_updated)
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(), compile->setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
compile.getContext())); compile.getContext()));
} }

View File

@ -94,7 +94,7 @@ LogicalResult GetRemappedPaddings(
.str(); .str();
}; };
Attribute padding_map_attr = cluster_func.getAttr(kPaddingMapAttr); Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
if (!padding_map_attr) return success(); if (!padding_map_attr) return success();
auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>(); auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();

View File

@ -114,7 +114,7 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
OwningModuleRef module_for_func = OwningModuleRef module_for_func =
ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext())); ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
auto parent_module = entry_func->getParentOfType<ModuleOp>(); auto parent_module = entry_func->getParentOfType<ModuleOp>();
auto versions_attr = parent_module.getAttr(kVersionsAttr); auto versions_attr = parent_module->getAttr(kVersionsAttr);
if (!versions_attr) if (!versions_attr)
return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr)); return parent_module.emitError(CreateMissingAttributeMsg(kVersionsAttr));
@ -457,7 +457,7 @@ void AssignDevicesToReplicate(
tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts))); tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
} }
replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs)); replicate->setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
} }
// Creates a `tf.TPUExecute` op that executes TPU program. // Creates a `tf.TPUExecute` op that executes TPU program.

View File

@ -130,7 +130,7 @@ void IdentifyXlaShardingForComputationInputs(
} }
} }
cluster_func_op.setAttr(tensorflow::kInputShardingAttr, cluster_func_op->setAttr(tensorflow::kInputShardingAttr,
builder->getStrArrayAttr(sharding_for_args)); builder->getStrArrayAttr(sharding_for_args));
} }
@ -202,7 +202,7 @@ void IdentifyXlaShardingForComputationOutputs(
} }
} }
cluster_func.setAttr(tensorflow::kOutputShardingAttr, cluster_func->setAttr(tensorflow::kOutputShardingAttr,
builder->getStrArrayAttr(sharding_for_rets)); builder->getStrArrayAttr(sharding_for_rets));
} }

View File

@ -200,7 +200,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) {
}); });
// TODO(b/157276506): change type of strides to DenseElementsAttr // TODO(b/157276506): change type of strides to DenseElementsAttr
auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context); auto strides = ArrayAttr::get(llvm::to_vector<4>(attrs), context);
conv2d.setAttr("strides", strides); conv2d->setAttr("strides", strides);
} }
// Transforms input shape for the first convolution. // Transforms input shape for the first convolution.

View File

@ -128,7 +128,7 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
auto outside_compilation_attr = auto outside_compilation_attr =
embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr); embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
if (outside_compilation_attr) if (outside_compilation_attr)
enqueue_mode.setAttr(kXlaOutsideCompilationAttr, enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
outside_compilation_attr); outside_compilation_attr);
mode_enqueue_operand.set(enqueue_mode); mode_enqueue_operand.set(enqueue_mode);