Use OpState::operator->() to get to member functions in Operation so we can remove the corresponding methods from OpState.
PiperOrigin-RevId: 347322042 Change-Id: I02b32db7aba1c33e4dd5510a294fcfff7b122d60
This commit is contained in:
parent
1e8c13a7f0
commit
9697081dac
@ -532,7 +532,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||
// quantization pass. These ops can be removed without losing original
|
||||
// program accuracy.
|
||||
// TODO(fengliuai): make the attribute being part of op definition.
|
||||
quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
|
||||
quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
|
||||
|
||||
// `original_result` has a use to `quantize`, so this will replace that use
|
||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||
|
@ -438,7 +438,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
|
||||
}
|
||||
|
||||
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
|
||||
auto attr = fused_func_op_.getAttrOfType<StringAttr>(kTFImplements);
|
||||
auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
|
||||
if (!attr) {
|
||||
return fused_func_op_.emitError()
|
||||
<< "Invalid function attribute, expected " << kTFImplements
|
||||
@ -639,7 +639,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
|
||||
// TFL lstm only supports time-majored inputs, so if it's not time-majored,
|
||||
// we will transpose the inputs and outputs.
|
||||
auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
|
||||
auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
|
||||
if (time_major_attr == nullptr) return failure();
|
||||
|
||||
bool time_majored = time_major_attr.getValue();
|
||||
@ -654,7 +654,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
|
||||
|
||||
// Handle go_backwards:
|
||||
// LSTM in Keras semantic will reverse the input sequence if it's go_backwards
|
||||
auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
|
||||
auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
|
||||
|
||||
if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
|
||||
int time_dim = time_majored ? 0 : 1;
|
||||
|
@ -1479,9 +1479,10 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
|
||||
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||
|
||||
// Update convolution attributes.
|
||||
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
(*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
(*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
(*this)->setAttr("explicit_paddings",
|
||||
ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -1553,9 +1554,10 @@ LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
|
||||
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||
|
||||
// Update convolution attributes.
|
||||
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
(*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
(*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
(*this)->setAttr("explicit_paddings",
|
||||
ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
|
||||
// Permute filter sizes operand.
|
||||
OpBuilder builder(getOperation());
|
||||
@ -1618,9 +1620,10 @@ LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) {
|
||||
if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
|
||||
|
||||
// Update convolution attributes.
|
||||
setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
(*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
|
||||
(*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
|
||||
(*this)->setAttr("explicit_paddings",
|
||||
ShuffleArrayAttr(explicit_paddings(), perm, 2));
|
||||
|
||||
// Permute input sizes operand.
|
||||
OpBuilder builder(getOperation());
|
||||
|
@ -370,7 +370,7 @@ LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
|
||||
if (perm.empty()) return failure();
|
||||
|
||||
// Update data format attribute.
|
||||
op->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
|
||||
(*op)->setAttr("data_format", StringAttr::get(data_format, op->getContext()));
|
||||
|
||||
// Update types for all layout sensitive results.
|
||||
auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
|
||||
@ -421,12 +421,12 @@ LogicalResult FoldOperandsPermutation(
|
||||
GetDataFormatPermutation(op->data_format(), target_data_format);
|
||||
if (reverse_permutation.empty()) return failure();
|
||||
|
||||
op->setAttr("data_format", StringAttr::get(target_data_format, context));
|
||||
(*op)->setAttr("data_format", StringAttr::get(target_data_format, context));
|
||||
|
||||
for (auto pair : shuffle_attrs) {
|
||||
StringRef attr_name = pair.first;
|
||||
ArrayAttr attr_value = pair.second;
|
||||
op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
|
||||
(*op)->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
|
||||
}
|
||||
|
||||
auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
|
||||
|
@ -401,7 +401,7 @@ static LogicalResult Verify(ParseExampleV2Op op) {
|
||||
template <class OpClass>
|
||||
static LogicalResult VerifyPartitionedCall(OpClass op) {
|
||||
auto module = op->template getParentOfType<ModuleOp>();
|
||||
SymbolRefAttr func = op.getAttr("f").template cast<SymbolRefAttr>();
|
||||
SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
|
||||
|
||||
auto function =
|
||||
dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
|
||||
|
@ -342,7 +342,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
|
||||
continue;
|
||||
}
|
||||
if (func.getArgAttr(i, "tf.resource_name")) {
|
||||
if (module.getAttr("tf_saved_model.under_construction")) continue;
|
||||
if (module->getAttr("tf_saved_model.under_construction")) continue;
|
||||
return func.emitError() << "'tf.resource_name' attribute is not allowed "
|
||||
"unless it is being under construction";
|
||||
}
|
||||
@ -355,7 +355,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
|
||||
if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
|
||||
i, "tf_saved_model.bound_input")) {
|
||||
if (!unique_bound_inputs.insert(attr.getValue()).second) {
|
||||
if (module.getAttr("tf_saved_model.under_construction")) continue;
|
||||
if (module->getAttr("tf_saved_model.under_construction")) continue;
|
||||
return func.emitError()
|
||||
<< "duplicate 'tf_saved_model.bound_input' binding";
|
||||
}
|
||||
@ -431,7 +431,7 @@ bool IsExported(Operation *op) {
|
||||
}
|
||||
|
||||
bool HasTfSavedModelSemantics(ModuleOp module) {
|
||||
return module.getAttr("tf_saved_model.semantics") != nullptr;
|
||||
return module->getAttr("tf_saved_model.semantics") != nullptr;
|
||||
}
|
||||
|
||||
Operation *LookupBoundInput(FuncOp func, int arg_index,
|
||||
@ -483,7 +483,7 @@ class OptimizeSessionInitializerPattern
|
||||
if (to_keep.empty())
|
||||
rewriter.eraseOp(op);
|
||||
else
|
||||
op.setAttr("initializers", rewriter.getArrayAttr(to_keep));
|
||||
op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -3064,7 +3064,7 @@ Status CreateSavedModelIR(
|
||||
/*executor_type=*/builder.getStringAttr(""));
|
||||
body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
|
||||
}
|
||||
func.setAttr(
|
||||
func->setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
|
||||
const SavedConcreteFunction& concrete_function =
|
||||
@ -3162,7 +3162,7 @@ Status CreateSavedModelIR(
|
||||
value_attr,
|
||||
/*type=*/mlir::TypeAttr::get(type),
|
||||
/*is_mutable=*/builder.getUnitAttr());
|
||||
op.setAttr(
|
||||
op->setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
|
||||
} else if (object.kind_case() == SavedObject::kConstant) {
|
||||
@ -3182,13 +3182,13 @@ Status CreateSavedModelIR(
|
||||
value_attr,
|
||||
/*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
|
||||
/*is_mutable=*/nullptr);
|
||||
op.setAttr(
|
||||
op->setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
|
||||
}
|
||||
}
|
||||
AdjustBoundInputArgTypes(module);
|
||||
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
SortSavedModelModule(module);
|
||||
MarkSavedModelFunctionVisibility(module);
|
||||
return Status::OK();
|
||||
@ -3448,7 +3448,7 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
|
||||
|
||||
// Set the exported name of init function to an reserved name for
|
||||
// tf_saved_model.
|
||||
init_func_op.setAttr(
|
||||
init_func_op->setAttr(
|
||||
"tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr({absl::StrCat(
|
||||
"__tf_saved_model_session_initializer_", target_node_name)}));
|
||||
@ -3508,8 +3508,8 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
|
||||
<< sig_def_key << ".";
|
||||
|
||||
// Use unique SignatureDef key as exported name.
|
||||
func_op.setAttr("tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr({sig_def_key}));
|
||||
func_op->setAttr("tf_saved_model.exported_names",
|
||||
builder.getStrArrayAttr({sig_def_key}));
|
||||
|
||||
// Transfer input and output parameter names to index_path attributes.
|
||||
for (auto input_and_idx : llvm::enumerate(inputs)) {
|
||||
@ -3623,7 +3623,7 @@ SavedModelSignatureDefImporterLite::ConvertSignatures() {
|
||||
builder.create<mlir::tf_saved_model::SessionInitializerOp>(
|
||||
module_->getLoc(), builder.getArrayAttr(init_sym_refs));
|
||||
|
||||
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
(*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
|
||||
SortSavedModelModule(*module_);
|
||||
MarkSavedModelFunctionVisibility(*module_);
|
||||
@ -3653,7 +3653,8 @@ class SavedModelSignatureDefImporter {
|
||||
context, upgrade_legacy, /*import_restore=*/false));
|
||||
|
||||
mlir::OpBuilder builder(module->getContext());
|
||||
module->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
|
||||
(*module)->setAttr("tf_saved_model.under_construction",
|
||||
builder.getUnitAttr());
|
||||
TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
|
||||
module->removeAttr("tf_saved_model.under_construction");
|
||||
|
||||
|
@ -30,9 +30,9 @@ void PopulateTfVersions(mlir::ModuleOp module, const VersionDef& versions) {
|
||||
"bad_consumers",
|
||||
b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
|
||||
versions.bad_consumers().begin(), versions.bad_consumers().end())));
|
||||
module.setAttr("tf.versions",
|
||||
b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
|
||||
{producer, min_consumer, bad_consumers})));
|
||||
module->setAttr("tf.versions",
|
||||
b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
|
||||
{producer, min_consumer, bad_consumers})));
|
||||
}
|
||||
|
||||
mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
|
||||
|
@ -92,8 +92,9 @@ mlir::LogicalResult CreateSplitOp(const int num_split,
|
||||
llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
|
||||
*split_op = builder->create<mlir::TF::SplitOp>(
|
||||
location, output_types, split_dimension_op.output(), src_input);
|
||||
split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr(
|
||||
builder->getIntegerType(32), num_split));
|
||||
(*split_op)->setAttr(
|
||||
kNumSplitAttr,
|
||||
builder->getIntegerAttr(builder->getIntegerType(32), num_split));
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
|
@ -231,7 +231,7 @@ static LogicalResult Verify(TFRFuncOp func) {
|
||||
// Collect all the undefined attributes used in the inputs.
|
||||
llvm::SmallVector<StringAttr, 4> undefined_attrs;
|
||||
for (auto attr : input_used_attrs) {
|
||||
if (!func.getAttr(attr.getValue())) {
|
||||
if (!func->getAttr(attr.getValue())) {
|
||||
undefined_attrs.push_back(attr);
|
||||
}
|
||||
}
|
||||
@ -295,7 +295,7 @@ static LogicalResult Verify(TFRFuncOp func) {
|
||||
|
||||
// Collect all the undefined attributes used in the outputs.
|
||||
for (auto attr : output_used_attrs) {
|
||||
if (!func.getAttr(attr.getValue())) {
|
||||
if (!func->getAttr(attr.getValue())) {
|
||||
undefined_attrs.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
|
||||
FuncOp func = getFunction();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
: func->getParentOfType<ModuleOp>());
|
||||
OpBuilder builder(func);
|
||||
bool changed = false;
|
||||
func.walk([&table, &builder, &changed](Operation* op) {
|
||||
@ -244,7 +244,7 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
|
||||
FuncOp func = getFunction();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
: func->getParentOfType<ModuleOp>());
|
||||
|
||||
// The inliner only inlines the TFR call op.
|
||||
bool changed = false;
|
||||
|
@ -450,7 +450,7 @@ void RaiseToTFOpsPass::runOnFunction() {
|
||||
MLIRContext* ctx = &getContext();
|
||||
SymbolTable table(external_tfr_module.hasValue()
|
||||
? *external_tfr_module
|
||||
: func.getParentOfType<ModuleOp>());
|
||||
: func->getParentOfType<ModuleOp>());
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
|
||||
|
@ -142,7 +142,7 @@ LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
|
||||
|
||||
// Unregistered attribute.
|
||||
if (GetAllowedAttributes().contains(attr_name)) {
|
||||
dst.setAttr(attr.first, attr.second);
|
||||
dst->setAttr(attr.first, attr.second);
|
||||
} else {
|
||||
src->emitError("Denied unregistered attribute was found: " + attr_name);
|
||||
return failure();
|
||||
|
@ -4637,11 +4637,11 @@ class ConvertInfeedDequeueTupleOp
|
||||
if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
|
||||
*sharding_proto.add_tuple_shardings() =
|
||||
::xla::sharding_builder::AssignDevice(0);
|
||||
data_and_token.setAttr(
|
||||
data_and_token->setAttr(
|
||||
kShardingAttr,
|
||||
rewriter.getStringAttr(sharding_proto.SerializeAsString()));
|
||||
} else {
|
||||
data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
}
|
||||
}
|
||||
|
||||
@ -5157,7 +5157,7 @@ class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
|
||||
/*call_target_name=*/rewriter.getStringAttr("Sharding"),
|
||||
/*has_side_effect=*/rewriter.getBoolAttr(false),
|
||||
/*backend_config=*/rewriter.getStringAttr(""));
|
||||
custom_call.setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
|
||||
rewriter.replaceOp(op, custom_call.getResult(0));
|
||||
|
||||
return success();
|
||||
|
Loading…
x
Reference in New Issue
Block a user