Partially add derived attribute materialize functions
Avoids warning when building dialects and in preparation for using these on export. Deferred shape materialization pending new shape attribute. PiperOrigin-RevId: 307925585 Change-Id: I3a764915ab514206e5fe8b1c98c9419f8cdd8b2b
This commit is contained in:
parent
49e59a8cad
commit
b5c3ed7e9b
@ -103,8 +103,8 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
|
||||
// Derived shape attribute class.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
|
||||
class DerivedTFLiteTypeAttr<code body> :
|
||||
DerivedAttr<"tflite::TensorType", body>;
|
||||
class DerivedTFLiteTypeAttr<code body, code convert> :
|
||||
DerivedAttr<"tflite::TensorType", body, convert>;
|
||||
|
||||
// TFL Runtime op trait predicate.
|
||||
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
|
||||
@ -572,6 +572,8 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
|
||||
return getResult().getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}], [{
|
||||
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
|
||||
}]>;
|
||||
}
|
||||
|
||||
@ -600,6 +602,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
|
||||
return getResult().getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}], [{
|
||||
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
|
||||
}]>;
|
||||
}
|
||||
|
||||
@ -3111,6 +3115,8 @@ in the unique output `y`. In other words:
|
||||
return getResult(1).getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}], [{
|
||||
TypeAttr::get(getResult(1).getType().cast<TensorType>().getElementType())
|
||||
}]>;
|
||||
|
||||
let hasOptions = 1;
|
||||
|
@ -233,7 +233,8 @@ def TF_ConvnetDataFormatAttr : StringBasedAttr<
|
||||
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
|
||||
"size_t",
|
||||
"auto range = getODSOperands(" # idx # ");\n"
|
||||
"return std::distance(range.begin(), range.end());">;
|
||||
"return std::distance(range.begin(), range.end());",
|
||||
[{ $_builder.getI64IntegerAttr($_self) }]>;
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// operand. If the `idx`-th operand is a variadic operand, then this attribute
|
||||
@ -251,7 +252,16 @@ class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::OperandElementTypeRange",
|
||||
"auto values = getODSOperands(" # idx # ");\n"
|
||||
"return {mlir::OperandElementTypeIterator(values.begin()), "
|
||||
"mlir::OperandElementTypeIterator(values.end())};"
|
||||
"mlir::OperandElementTypeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get(
|
||||
[&]() {
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto t : $_self)
|
||||
ret.push_back(TypeAttr::get(t));
|
||||
return ret;
|
||||
}(), $_ctx)
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shapes of the tensors in the actual
|
||||
@ -262,7 +272,9 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::TF::OperandShapeRange",
|
||||
"auto values = getODSOperands(" # idx # ");\n"
|
||||
"return {mlir::TF::OperandShapeIterator(values.begin()), "
|
||||
"mlir::TF::OperandShapeIterator(values.end())};"
|
||||
"mlir::TF::OperandShapeIterator(values.end())};",
|
||||
// TODO(jpienaar): Update post TensorShapeAttr landing.
|
||||
[{ nullptr }]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
|
||||
@ -270,7 +282,8 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
|
||||
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
|
||||
"size_t",
|
||||
"auto range = getODSResults(" # idx # ");\n"
|
||||
"return std::distance(range.begin(), range.end());">;
|
||||
"return std::distance(range.begin(), range.end());",
|
||||
[{ $_builder.getI64IntegerAttr($_self) }]>;
|
||||
|
||||
// A derived attribute that returns the element type of `idx`-th ODS-declared
|
||||
// result. If the `idx`-th result is a variadic result, then this attribute
|
||||
@ -288,7 +301,16 @@ class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::ResultElementTypeRange",
|
||||
"auto values = getODSResults(" # idx # ");\n"
|
||||
"return {mlir::ResultElementTypeIterator(values.begin()), "
|
||||
"mlir::ResultElementTypeIterator(values.end())};"
|
||||
"mlir::ResultElementTypeIterator(values.end())};",
|
||||
[{
|
||||
ArrayAttr::get(
|
||||
[&]() {
|
||||
llvm::SmallVector<Attribute, 4> ret;
|
||||
for (auto t : $_self)
|
||||
ret.push_back(TypeAttr::get(t));
|
||||
return ret;
|
||||
}(), $_ctx)
|
||||
}]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shapes of the tensors in the actual
|
||||
@ -299,12 +321,15 @@ class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
|
||||
"mlir::TF::ResultShapeRange",
|
||||
"auto values = getODSResults(" # idx # ");\n"
|
||||
"return {mlir::TF::ResultShapeIterator(values.begin()), "
|
||||
"mlir::TF::ResultShapeIterator(values.end())};"
|
||||
"mlir::TF::ResultShapeIterator(values.end())};",
|
||||
// TODO(jpienaar): Update post TensorShapeAttr landing.
|
||||
[{ nullptr }]
|
||||
>;
|
||||
|
||||
// A derived attribute that returns the shape of the first result type.
|
||||
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
|
||||
"return (*getOperation()->result_type_begin()).cast<ShapedType>();">;
|
||||
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
|
||||
[{ TypeAttr::get($_self) }]>;
|
||||
|
||||
// A derived attribute that returns the element type of the tensor held by a
|
||||
// named resource-type operand or result.
|
||||
@ -315,7 +340,6 @@ class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
|
||||
"assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
|
||||
"return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
|
||||
|
||||
|
||||
// A derived attribute that returns the shape of the tensor held by a named
|
||||
// resource-type operand or result.
|
||||
class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
|
||||
@ -324,7 +348,8 @@ class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
|
||||
" mlir::getElementTypeOrSelf(this->" # name # "())\n"
|
||||
" .cast<TF::ResourceType>();\n"
|
||||
"assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
|
||||
"return resource_type.getSubtypes().begin()->cast<ShapedType>();">;
|
||||
"return resource_type.getSubtypes().begin()->cast<ShapedType>();",
|
||||
[{ TypeAttr::get($_self) }]>;
|
||||
|
||||
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
|
||||
let returnType = "Type";
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#define TF_OPS
|
||||
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
Loading…
Reference in New Issue
Block a user