diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index e9192388070..6a631b1433d 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1020,7 +1020,7 @@ Optional> Translator::BuildOperator( if (!inst->getMutableAttrDict().getAttrs().empty()) { os << " {"; bool first = true; - for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) { + for (auto& named_attr : inst->getAttrDictionary()) { os << (!first ? ", " : ""); first = false; named_attr.first.print(os); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 3a4e9e5985e..edba135e0f0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -984,7 +984,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, LogicalResult ConstOp::inferReturnTypes( MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { for (NamedAttribute named_attr : attributes) { if (named_attr.first.strref() != "value") continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 789088bd585..efe82c4268b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -323,8 +323,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, Operation* op = infer_ti.getOperation(); SmallVector inferred; LogicalResult res = infer_ti.inferReturnTypes( - op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(), - op->getRegions(), inferred); + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferred); if (failed(res)) { op->emitOpError("failed to refine type as inference failed"); return false; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3bb1446213b..a613ce1f920 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -112,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) { } namespace tensorflow { +using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; using mlir::tf_saved_model::GlobalTensorOp; @@ -309,9 +310,9 @@ class ImporterBase { // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, // {base_name.k2 : rfc}}. - Status ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes); + Status ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped // in an island. When convert_to_legacy_call is true, converts the operation @@ -1092,9 +1093,9 @@ StatusOr ImporterBase::ConvertSubtypes( return subtypes; } -Status ImporterBase::ConvertFunctionCallAttribute( - const std::string& base_name, const AttrValue& value, - llvm::SmallVector* attributes) { +Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index bc6842a617e..5322668aa2e 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -97,16 +97,12 @@ static Type GetBroadcastType(Type x, Type y, Type element_type, LogicalResult InferBroadcastBinaryOpReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, Type element_type, + DictionaryAttr attributes, Type element_type, SmallVectorImpl& inferedReturnShapes) { // Find broadcast_dimensions. - DenseIntElementsAttr broadcast_dimensions; - for (auto attr : attributes) { - if (attr.first == "broadcast_dimensions") { - broadcast_dimensions = attr.second.dyn_cast(); - break; - } - } + DenseIntElementsAttr broadcast_dimensions = + attributes.get("broadcast_dimensions") + .dyn_cast_or_null(); ShapedType lhs_type = operands[0].getType().dyn_cast(); ShapedType rhs_type = operands[1].getType().dyn_cast(); @@ -168,7 +164,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( LogicalResult BroadcastComplexOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { ShapedType lhs_type = operands[0].getType().dyn_cast(); if (!lhs_type) { @@ -191,7 +187,7 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( LogicalResult BroadcastCompareOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { Type element_type = IntegerType::get(1, context); return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, @@ -211,7 +207,7 @@ LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ MLIRContext* context, Optional location, ValueRange operands, \ - ArrayRef attributes, RegionRange regions, \ + DictionaryAttr attributes, RegionRange regions, \ SmallVectorImpl& inferedReturnShapes) { \ return InferBroadcastBinaryOpReturnTypeComponents( \ context, location, operands, attributes, /*element_type=*/nullptr, \ diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index cb7372a762c..874f4a1587e 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -1240,7 +1240,7 @@ static LogicalResult Verify(SelectOp op) { // the return type based on operand type. LogicalResult SelectOp::inferReturnTypes( MLIRContext*, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { auto x_type = operands[1].getType(); auto y_type = operands[2].getType(); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index c2eef4a90b2..917a50f74ea 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -104,8 +104,7 @@ class HLO_UnaryElementwiseOp traits, let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, - ValueRange operands, ArrayRef attributes, - RegionRange regions, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } @@ -254,7 +253,7 @@ class HLO_BinaryElementwiseOp traits> : let extraClassDeclaration = [{ static LogicalResult inferReturnTypeComponents( MLIRContext* context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferedReturnShapes) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc index 8976bd5b7d2..71441656c08 100644 --- a/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/test_infer_shaped_type_pass.cc @@ -38,7 +38,8 @@ struct InferReturnTypeComponentsPattern : public RewritePattern { SmallVector components; if (failed(defining_op_int.inferReturnTypeComponents( op->getContext(), op->getLoc(), defining_op->getOperands(), - defining_op->getAttrs(), defining_op->getRegions(), components))) { + defining_op->getAttrDictionary(), defining_op->getRegions(), + components))) { return failure(); }