diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
index d847a7d52e6..831a67078e1 100644
--- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc
@@ -532,7 +532,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
   // quantization pass. These ops can be removed without losing original
   // program accuracy.
   // TODO(fengliuai): make the attribute being part of op definition.
-  quantize.setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
+  quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr());
 
   // `original_result` has a use to `quantize`, so this will replace that use
   // by the result of `dequantize`. Remember to reset that use afterwards
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
index 357079c561b..1a503675f45 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc
@@ -438,7 +438,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
 }
 
 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
-  auto attr = fused_func_op_.getAttrOfType<StringAttr>(kTFImplements);
+  auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
   if (!attr) {
     return fused_func_op_.emitError()
            << "Invalid function attribute, expected " << kTFImplements
@@ -639,7 +639,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
 
   // TFL lstm only supports time-majored inputs, so if it's not time-majored,
   // we will transpose the inputs and outputs.
-  auto time_major_attr = func_op.getAttrOfType<BoolAttr>("tf.time_major");
+  auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
   if (time_major_attr == nullptr) return failure();
 
   bool time_majored = time_major_attr.getValue();
@@ -654,7 +654,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
 
   // Handle go_backwards:
   // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
-  auto go_backwards_attr = func_op.getAttrOfType<BoolAttr>("tf.go_backwards");
+  auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
 
   if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
     int time_dim = time_majored ? 0 : 1;
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 513f8338343..93e0113ce4a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -1479,9 +1479,10 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) {
   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
 
   // Update convolution attributes.
-  setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
-  setAttr("strides", ShuffleArrayAttr(strides(), perm));
-  setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2));
+  (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm));
+  (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm));
+  (*this)->setAttr("explicit_paddings",
+                   ShuffleArrayAttr(explicit_paddings(), perm, 2));
 
   return success();
 }
@@ -1553,9 +1554,10 @@ LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) {
   if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure();
 
   // Update convolution attributes.
- setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); + (*this)->setAttr("explicit_paddings", + ShuffleArrayAttr(explicit_paddings(), perm, 2)); // Permute filter sizes operand. OpBuilder builder(getOperation()); @@ -1618,9 +1620,10 @@ LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); + (*this)->setAttr("explicit_paddings", + ShuffleArrayAttr(explicit_paddings(), perm, 2)); // Permute input sizes operand. OpBuilder builder(getOperation()); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index de8bc8311b7..dddb9bca67f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -370,7 +370,7 @@ LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { if (perm.empty()) return failure(); // Update data format attribute. - op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); + (*op)->setAttr("data_format", StringAttr::get(data_format, op->getContext())); // Update types for all layout sensitive results. 
   auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
@@ -421,12 +421,12 @@ LogicalResult FoldOperandsPermutation(
       GetDataFormatPermutation(op->data_format(), target_data_format);
   if (reverse_permutation.empty()) return failure();
 
-  op->setAttr("data_format", StringAttr::get(target_data_format, context));
+  (*op)->setAttr("data_format", StringAttr::get(target_data_format, context));
 
   for (auto pair : shuffle_attrs) {
     StringRef attr_name = pair.first;
     ArrayAttr attr_value = pair.second;
-    op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
+    (*op)->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation));
   }
 
   auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
index 5d681295f61..0208e377d19 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
@@ -401,7 +401,7 @@ static LogicalResult Verify(ParseExampleV2Op op) {
 template <class OpClass>
 static LogicalResult VerifyPartitionedCall(OpClass op) {
   auto module = op->template getParentOfType<ModuleOp>();
-  SymbolRefAttr func = op.getAttr("f").template cast<SymbolRefAttr>();
+  SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
 
   auto function =
       dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(module, func));
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index 85cb8edb8c7..3edcbf505dd 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -342,7 +342,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
       continue;
     }
    if (func.getArgAttr(i, "tf.resource_name")) {
-      if (module.getAttr("tf_saved_model.under_construction")) continue;
+      if (module->getAttr("tf_saved_model.under_construction")) continue;
      return func.emitError() << "'tf.resource_name' attribute is not allowed "
                                 "unless it is being under construction";
    }
@@ -355,7 +355,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
    if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
            i, "tf_saved_model.bound_input")) {
      if (!unique_bound_inputs.insert(attr.getValue()).second) {
-        if (module.getAttr("tf_saved_model.under_construction")) continue;
+        if (module->getAttr("tf_saved_model.under_construction")) continue;
        return func.emitError()
               << "duplicate 'tf_saved_model.bound_input' binding";
      }
@@ -431,7 +431,7 @@ bool IsExported(Operation *op) {
 }
 
 bool HasTfSavedModelSemantics(ModuleOp module) {
-  return module.getAttr("tf_saved_model.semantics") != nullptr;
+  return module->getAttr("tf_saved_model.semantics") != nullptr;
 }
 
 Operation *LookupBoundInput(FuncOp func, int arg_index,
@@ -483,7 +483,7 @@ class OptimizeSessionInitializerPattern
    if (to_keep.empty())
      rewriter.eraseOp(op);
    else
-      op.setAttr("initializers", rewriter.getArrayAttr(to_keep));
+      op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
 
    return success();
  }
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 3099554f5c7..d0ae9dc8aee 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -3064,7 +3064,7 @@ Status CreateSavedModelIR(
          /*executor_type=*/builder.getStringAttr(""));
      body_builder.create<mlir::ReturnOp>(func.getLoc(), call.getResults());
    }
-    func.setAttr(
+    func->setAttr(
        "tf_saved_model.exported_names",
        builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
 
    const SavedConcreteFunction& concrete_function =
@@ -3162,7 +3162,7 @@ Status CreateSavedModelIR(
          value_attr,
          /*type=*/mlir::TypeAttr::get(type),
          /*is_mutable=*/builder.getUnitAttr());
-      op.setAttr(
+      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    } else if (object.kind_case() == SavedObject::kConstant) {
@@ -3182,13 +3182,13 @@ Status CreateSavedModelIR(
          value_attr,
          /*type=*/mlir::TypeAttr::get(value_attr.Attribute::getType()),
          /*is_mutable=*/nullptr);
-      op.setAttr(
+      op->setAttr(
          "tf_saved_model.exported_names",
          builder.getStrArrayAttr(object_names.GetExportedNames(node_id)));
    }
  }
  AdjustBoundInputArgTypes(module);
-  module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+  module->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
  SortSavedModelModule(module);
  MarkSavedModelFunctionVisibility(module);
  return Status::OK();
@@ -3448,7 +3448,7 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer(
 
  // Set the exported name of init function to an reserved name for
  // tf_saved_model.
-  init_func_op.setAttr(
+  init_func_op->setAttr(
      "tf_saved_model.exported_names",
      builder.getStrArrayAttr({absl::StrCat(
          "__tf_saved_model_session_initializer_", target_node_name)}));
@@ -3508,8 +3508,8 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature(
                                       << sig_def_key << ".";
 
  // Use unique SignatureDef key as exported name.
-  func_op.setAttr("tf_saved_model.exported_names",
-                  builder.getStrArrayAttr({sig_def_key}));
+  func_op->setAttr("tf_saved_model.exported_names",
+                   builder.getStrArrayAttr({sig_def_key}));
 
  // Transfer input and output parameter names to index_path attributes.
  for (auto input_and_idx : llvm::enumerate(inputs)) {
@@ -3623,7 +3623,7 @@ SavedModelSignatureDefImporterLite::ConvertSignatures() {
    builder.create<tf_saved_model::SessionInitializerOp>(
        module_->getLoc(), builder.getArrayAttr(init_sym_refs));
 
-  module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
+  (*module_)->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
 
  SortSavedModelModule(*module_);
  MarkSavedModelFunctionVisibility(*module_);
@@ -3653,7 +3653,8 @@ class SavedModelSignatureDefImporter {
                         context, upgrade_legacy, /*import_restore=*/false));
 
    mlir::OpBuilder builder(module->getContext());
-    module->setAttr("tf_saved_model.under_construction", builder.getUnitAttr());
+    (*module)->setAttr("tf_saved_model.under_construction",
+                       builder.getUnitAttr());
    TF_RETURN_IF_ERROR(LiftVariables(bundle, *module));
    module->removeAttr("tf_saved_model.under_construction");
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
index d7b9a5c2f45..075d33a348c 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc
@@ -30,9 +30,9 @@ void PopulateTfVersions(mlir::ModuleOp module, const VersionDef& versions) {
      "bad_consumers",
      b.getI32ArrayAttr(llvm::ArrayRef<int32_t>(
          versions.bad_consumers().begin(), versions.bad_consumers().end())));
-  module.setAttr("tf.versions",
-                 b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
-                     {producer, min_consumer, bad_consumers})));
+  module->setAttr("tf.versions",
+                  b.getDictionaryAttr(llvm::ArrayRef<mlir::NamedAttribute>(
+                      {producer, min_consumer, bad_consumers})));
 }
 
 mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
index a3f8e833ae3..82939c9d600 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc
@@ -92,8 +92,9 @@ mlir::LogicalResult CreateSplitOp(const int num_split,
  llvm::SmallVector<mlir::Type, 4> output_types(num_split, output_type);
  *split_op = builder->create<mlir::TF::SplitOp>(
      location, output_types, split_dimension_op.output(), src_input);
-  split_op->setAttr(kNumSplitAttr, builder->getIntegerAttr(
-                                       builder->getIntegerType(32), num_split));
+  (*split_op)->setAttr(
+      kNumSplitAttr,
+      builder->getIntegerAttr(builder->getIntegerType(32), num_split));
 
  return mlir::success();
 }
diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
index e1ef506ba1f..be01511510c 100644
--- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
+++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc
@@ -231,7 +231,7 @@ static LogicalResult Verify(TFRFuncOp func) {
  // Collect all the undefined attributes used in the inputs.
  llvm::SmallVector<StringAttr, 4> undefined_attrs;
  for (auto attr : input_used_attrs) {
-    if (!func.getAttr(attr.getValue())) {
+    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }
@@ -295,7 +295,7 @@ static LogicalResult Verify(TFRFuncOp func) {
 
  // Collect all the undefined attributes used in the outputs.
  for (auto attr : output_used_attrs) {
-    if (!func.getAttr(attr.getValue())) {
+    if (!func->getAttr(attr.getValue())) {
      undefined_attrs.push_back(attr);
    }
  }
diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
index 13d5f45e0ab..c532bc103a9 100644
--- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc
+++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc
@@ -111,7 +111,7 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() {
  FuncOp func = getFunction();
  SymbolTable table(external_tfr_module.hasValue()
                        ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
  OpBuilder builder(func);
  bool changed = false;
  func.walk([&table, &builder, &changed](Operation* op) {
@@ -244,7 +244,7 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() {
  FuncOp func = getFunction();
  SymbolTable table(external_tfr_module.hasValue()
                        ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
 
  // The inliner only inlines the TFR call op.
  bool changed = false;
diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
index 7ffcd4c7b11..d3780a4ef26 100644
--- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
+++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc
@@ -450,7 +450,7 @@ void RaiseToTFOpsPass::runOnFunction() {
  MLIRContext* ctx = &getContext();
  SymbolTable table(external_tfr_module.hasValue()
                        ? *external_tfr_module
-                        : func.getParentOfType<ModuleOp>());
+                        : func->getParentOfType<ModuleOp>());
 
  OwningRewritePatternList patterns;
  patterns.insert<RewriteTFRCallOp>(ctx, table, materialize_derived_attrs);
diff --git a/tensorflow/compiler/mlir/tfr/utils/utils.cc b/tensorflow/compiler/mlir/tfr/utils/utils.cc
index 253a109358b..2dec56074af 100644
--- a/tensorflow/compiler/mlir/tfr/utils/utils.cc
+++ b/tensorflow/compiler/mlir/tfr/utils/utils.cc
@@ -142,7 +142,7 @@ LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
 
    // Unregistered attribute.
    if (GetAllowedAttributes().contains(attr_name)) {
-      dst.setAttr(attr.first, attr.second);
+      dst->setAttr(attr.first, attr.second);
    } else {
      src->emitError("Denied unregistered attribute was found: " + attr_name);
      return failure();
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 2d2e7197d0e..18041f98c07 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -4637,11 +4637,11 @@ class ConvertInfeedDequeueTupleOp
    if (sharding_proto.type() == ::xla::OpSharding::TUPLE) {
      *sharding_proto.add_tuple_shardings() =
          ::xla::sharding_builder::AssignDevice(0);
-      data_and_token.setAttr(
+      data_and_token->setAttr(
          kShardingAttr,
          rewriter.getStringAttr(sharding_proto.SerializeAsString()));
    } else {
-      data_and_token.setAttr(kShardingAttr, op._XlaShardingAttr());
+      data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr());
    }
  }
 
@@ -5157,7 +5157,7 @@ class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
        /*call_target_name=*/rewriter.getStringAttr("Sharding"),
        /*has_side_effect=*/rewriter.getBoolAttr(false),
        /*backend_config=*/rewriter.getStringAttr(""));
-    custom_call.setAttr(kShardingAttr, op._XlaShardingAttr());
+    custom_call->setAttr(kShardingAttr, op._XlaShardingAttr());
 
    rewriter.replaceOp(op, custom_call.getResult(0));
    return success();