PiperOrigin-RevId: 310785048
Change-Id: I2d93edb9c66c4262c985fa088f88ad22e3e6cada
This commit is contained in:
A. Unique TensorFlower 2020-05-10 04:47:55 -07:00 committed by TensorFlower Gardener
parent 7e08dcef2c
commit d716200266
8 changed files with 23 additions and 26 deletions

View File

@ -1020,7 +1020,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (!inst->getMutableAttrDict().getAttrs().empty()) { if (!inst->getMutableAttrDict().getAttrs().empty()) {
os << " {"; os << " {";
bool first = true; bool first = true;
for (auto& named_attr : inst->getMutableAttrDict().getDictionary()) { for (auto& named_attr : inst->getAttrDictionary()) {
os << (!first ? ", " : ""); os << (!first ? ", " : "");
first = false; first = false;
named_attr.first.print(os); named_attr.first.print(os);

View File

@ -984,7 +984,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type,
LogicalResult ConstOp::inferReturnTypes( LogicalResult ConstOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands, MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
for (NamedAttribute named_attr : attributes) { for (NamedAttribute named_attr : attributes) {
if (named_attr.first.strref() != "value") continue; if (named_attr.first.strref() != "value") continue;

View File

@ -323,8 +323,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti,
Operation* op = infer_ti.getOperation(); Operation* op = infer_ti.getOperation();
SmallVector<Type, 4> inferred; SmallVector<Type, 4> inferred;
LogicalResult res = infer_ti.inferReturnTypes( LogicalResult res = infer_ti.inferReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(), op->getAttrs(), op->getContext(), op->getLoc(), op->getOperands(),
op->getRegions(), inferred); op->getAttrDictionary(), op->getRegions(), inferred);
if (failed(res)) { if (failed(res)) {
op->emitOpError("failed to refine type as inference failed"); op->emitOpError("failed to refine type as inference failed");
return false; return false;

View File

@ -112,6 +112,7 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) {
} }
namespace tensorflow { namespace tensorflow {
using mlir::NamedAttrList;
using mlir::TensorType; using mlir::TensorType;
using mlir::TF::VarHandleOp; using mlir::TF::VarHandleOp;
using mlir::tf_saved_model::GlobalTensorOp; using mlir::tf_saved_model::GlobalTensorOp;
@ -309,9 +310,9 @@ class ImporterBase {
// AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to
// a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar}, // a list of MLIR Attributes: [{base_name : foo}, {base_name.k1 : bar},
// {base_name.k2 : rfc}}. // {base_name.k2 : rfc}}.
Status ConvertFunctionCallAttribute( Status ConvertFunctionCallAttribute(const std::string& base_name,
const std::string& base_name, const AttrValue& value, const AttrValue& value,
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes); NamedAttrList* attributes);
// Helper to create either a tf_executor operation or a TF operation wrapped // Helper to create either a tf_executor operation or a TF operation wrapped
// in an island. When convert_to_legacy_call is true, converts the operation // in an island. When convert_to_legacy_call is true, converts the operation
@ -1092,9 +1093,9 @@ StatusOr<ImporterBase::ElementSubtypes> ImporterBase::ConvertSubtypes(
return subtypes; return subtypes;
} }
Status ImporterBase::ConvertFunctionCallAttribute( Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
const std::string& base_name, const AttrValue& value, const AttrValue& value,
llvm::SmallVector<mlir::NamedAttribute, 4>* attributes) { NamedAttrList* attributes) {
TF_ASSIGN_OR_RETURN(auto func_attr, TF_ASSIGN_OR_RETURN(auto func_attr,
ConvertFunctionCallName(value.func().name())); ConvertFunctionCallName(value.func().name()));
attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); attributes->push_back(builder_.getNamedAttr(base_name, func_attr));

View File

@ -97,16 +97,12 @@ static Type GetBroadcastType(Type x, Type y, Type element_type,
LogicalResult InferBroadcastBinaryOpReturnTypeComponents( LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands, MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, Type element_type, DictionaryAttr attributes, Type element_type,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
// Find broadcast_dimensions. // Find broadcast_dimensions.
DenseIntElementsAttr broadcast_dimensions; DenseIntElementsAttr broadcast_dimensions =
for (auto attr : attributes) { attributes.get("broadcast_dimensions")
if (attr.first == "broadcast_dimensions") { .dyn_cast_or_null<DenseIntElementsAttr>();
broadcast_dimensions = attr.second.dyn_cast<DenseIntElementsAttr>();
break;
}
}
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>(); ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>(); ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
@ -168,7 +164,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
LogicalResult BroadcastComplexOp::inferReturnTypeComponents( LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands, MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>(); ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
if (!lhs_type) { if (!lhs_type) {
@ -191,7 +187,7 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
LogicalResult BroadcastCompareOp::inferReturnTypeComponents( LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands, MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
Type element_type = IntegerType::get(1, context); Type element_type = IntegerType::get(1, context);
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
@ -211,7 +207,7 @@ LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ #define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
LogicalResult Op::inferReturnTypeComponents( \ LogicalResult Op::inferReturnTypeComponents( \
MLIRContext* context, Optional<Location> location, ValueRange operands, \ MLIRContext* context, Optional<Location> location, ValueRange operands, \
ArrayRef<NamedAttribute> attributes, RegionRange regions, \ DictionaryAttr attributes, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \ SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
return InferBroadcastBinaryOpReturnTypeComponents( \ return InferBroadcastBinaryOpReturnTypeComponents( \
context, location, operands, attributes, /*element_type=*/nullptr, \ context, location, operands, attributes, /*element_type=*/nullptr, \

View File

@ -1240,7 +1240,7 @@ static LogicalResult Verify(SelectOp op) {
// the return type based on operand type. // the return type based on operand type.
LogicalResult SelectOp::inferReturnTypes( LogicalResult SelectOp::inferReturnTypes(
MLIRContext*, Optional<Location> location, ValueRange operands, MLIRContext*, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type>& inferredReturnTypes) { SmallVectorImpl<Type>& inferredReturnTypes) {
auto x_type = operands[1].getType(); auto x_type = operands[1].getType();
auto y_type = operands[2].getType(); auto y_type = operands[2].getType();

View File

@ -104,8 +104,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents( static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, MLIRContext* context, Optional<Location> location,
ValueRange operands, ArrayRef<NamedAttribute> attributes, ValueRange operands, DictionaryAttr attributes, RegionRange regions,
RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure(); return failure();
} }
@ -254,7 +253,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static LogicalResult inferReturnTypeComponents( static LogicalResult inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands, MLIRContext* context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
return failure(); return failure();
} }

View File

@ -38,7 +38,8 @@ struct InferReturnTypeComponentsPattern : public RewritePattern {
SmallVector<ShapedTypeComponents, 4> components; SmallVector<ShapedTypeComponents, 4> components;
if (failed(defining_op_int.inferReturnTypeComponents( if (failed(defining_op_int.inferReturnTypeComponents(
op->getContext(), op->getLoc(), defining_op->getOperands(), op->getContext(), op->getLoc(), defining_op->getOperands(),
defining_op->getAttrs(), defining_op->getRegions(), components))) { defining_op->getAttrDictionary(), defining_op->getRegions(),
components))) {
return failure(); return failure();
} }