PiperOrigin-RevId: 310785048
Change-Id: I2d93edb9c66c4262c985fa088f88ad22e3e6cada
This commit is contained in:
A. Unique TensorFlower 2020-05-10 04:47:55 -07:00 committed by TensorFlower Gardener
parent 7e08dcef2c
commit d716200266
8 changed files with 23 additions and 26 deletions

View File

@ -1020,7 +1020,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (!inst->getMutableAttrDict().getAttrs().empty()) {
os << " {";
bool first = true;
for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) {
for (auto& named_attr : inst->getAttrDictionary()) {
os << (!first ? ", " : "");
first = false;
named_attr.first.print(os);

View File

@ -984,7 +984,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
LogicalResult ConstOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
for (NamedAttribute named_attr : attributes) {
if (named_attr.first.strref() != "value") continue;

View File

@ -323,8 +323,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
Operation* op = infer_ti.getOperation();
SmallVector<Type, 4> inferred;
LogicalResult res = infer_ti.inferReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(),
op->getRegions(), inferred);
op->getContext(), op->getLoc(), op->getOperands(),
op->getAttrDictionary(), op->getRegions(), inferred);
if (failed(res)) {
op->emitOpError("failed to refine type as inference failed");
return false;

View File

@ -112,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) {
}
namespace tensorflow {
using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::TF::VarHandleOp;
using mlir::tf_saved_model::GlobalTensorOp;
@ -309,9 +310,9 @@ class ImporterBase {
// AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
// a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
// {base_name.k2 : rfc}}.
Status ConvertFunctionCallAttribute(
const std::string& base_name, const AttrValue& value,
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes);
Status ConvertFunctionCallAttribute(const std::string& base_name,
const AttrValue& value,
NamedAttrList* attributes);
// Helper to create either a tf_executor operation or a TF operation wrapped
// in an island. When convert_to_legacy_call is true, converts the operation
@ -1092,9 +1093,9 @@ StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
return subtypes;
}
Status ImporterBase::ConvertFunctionCallAttribute(
const std::string& base_name, const AttrValue& value,
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) {
Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
const AttrValue& value,
NamedAttrList* attributes) {
TF_ASSIGN_OR_RETURN(auto func_attr,
ConvertFunctionCallName(value.func().name()));
attributes->push_back(builder_.getNamedAttr(base_name, func_attr));

View File

@ -97,16 +97,12 @@ static Type GetBroadcastType(Type x, Type y, Type element_type,
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, Type element_type,
DictionaryAttr attributes, Type element_type,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
// Find broadcast_dimensions.
DenseIntElementsAttr broadcast_dimensions;
for (auto attr : attributes) {
if (attr.first == "broadcast_dimensions") {
broadcast_dimensions = attr.second.dyn_cast<DenseIntElementsAttr>();
break;
}
}
DenseIntElementsAttr broadcast_dimensions =
attributes.get("broadcast_dimensions")
.dyn_cast_or_null<DenseIntElementsAttr>();
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
@ -168,7 +164,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
if (!lhs_type) {
@ -191,7 +187,7 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
Type element_type = IntegerType::get(1, context);
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
@ -211,7 +207,7 @@ LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
LogicalResult Op::inferReturnTypeComponents( \
MLIRContext* context, Optional<Location> location, ValueRange operands, \
ArrayRef<NamedAttribute> attributes, RegionRange regions, \
DictionaryAttr attributes, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
return InferBroadcastBinaryOpReturnTypeComponents( \
context, location, operands, attributes, /*element_type=*/nullptr, \

View File

@ -1240,7 +1240,7 @@ static LogicalResult Verify(SelectOp op) {
// the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes(
MLIRContext*, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto x_type = operands[1].getType();
auto y_type = operands[2].getType();

View File

@ -104,8 +104,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location,
ValueRange operands, ArrayRef<NamedAttribute> attributes,
RegionRange regions,
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}
@ -254,7 +253,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure();
}

View File

@ -38,7 +38,8 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
SmallVector<ShapedTypeComponents, 4> components;
if (failed(defining_op_int.inferReturnTypeComponents(
op->getContext(), op->getLoc(), defining_op->getOperands(),
defining_op->getAttrs(), defining_op->getRegions(), components))) {
defining_op->getAttrDictionary(), defining_op->getRegions(),
components))) {
return failure();
}