Partially add derived attribute materialize functions

Avoids warning when building dialects and in preparation for using these on
export. Deferred shape materialization pending new shape attribute.

PiperOrigin-RevId: 307925585
Change-Id: I3a764915ab514206e5fe8b1c98c9419f8cdd8b2b
This commit is contained in:
Jacques Pienaar 2020-04-22 16:20:26 -07:00 committed by TensorFlower Gardener
parent 49e59a8cad
commit b5c3ed7e9b
3 changed files with 43 additions and 11 deletions

View File

@ -103,8 +103,8 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Derived shape attribute class. // Derived shape attribute class.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>; class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> : class DerivedTFLiteTypeAttr<code body, code convert> :
DerivedAttr<"tflite::TensorType", body>; DerivedAttr<"tflite::TensorType", body, convert>;
// TFL Runtime op trait predicate. // TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> : class TFL_RuntimePredOpTrait<string desc, Pred pred> :
@ -572,6 +572,8 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
return getResult().getType().cast<TensorType>().getElementType(). return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 : cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32; tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
}]>; }]>;
} }
@ -600,6 +602,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
return getResult().getType().cast<TensorType>().getElementType(). return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 : cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32; tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
}]>; }]>;
} }
@ -3111,6 +3115,8 @@ in the unique output `y`. In other words:
return getResult(1).getType().cast<TensorType>().getElementType(). return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 : cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32; tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult(1).getType().cast<TensorType>().getElementType())
}]>; }]>;
let hasOptions = 1; let hasOptions = 1;

View File

@ -233,7 +233,8 @@ def TF_ConvnetDataFormatAttr : StringBasedAttr<
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr< class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
"size_t", "size_t",
"auto range = getODSOperands(" # idx # ");\n" "auto range = getODSOperands(" # idx # ");\n"
"return std::distance(range.begin(), range.end());">; "return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared // A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute // operand. If the `idx`-th operand is a variadic operand, then this attribute
@ -251,7 +252,16 @@ class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
"mlir::OperandElementTypeRange", "mlir::OperandElementTypeRange",
"auto values = getODSOperands(" # idx # ");\n" "auto values = getODSOperands(" # idx # ");\n"
"return {mlir::OperandElementTypeIterator(values.begin()), " "return {mlir::OperandElementTypeIterator(values.begin()), "
"mlir::OperandElementTypeIterator(values.end())};" "mlir::OperandElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}]
>; >;
// A derived attribute that returns the shapes of the tensors in the actual // A derived attribute that returns the shapes of the tensors in the actual
@ -262,7 +272,9 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::OperandShapeRange", "mlir::TF::OperandShapeRange",
"auto values = getODSOperands(" # idx # ");\n" "auto values = getODSOperands(" # idx # ");\n"
"return {mlir::TF::OperandShapeIterator(values.begin()), " "return {mlir::TF::OperandShapeIterator(values.begin()), "
"mlir::TF::OperandShapeIterator(values.end())};" "mlir::TF::OperandShapeIterator(values.end())};",
// TODO(jpienaar): Update post TensorShapeAttr landing.
[{ nullptr }]
>; >;
// A derived attribute that returns the size of `idx`-th ODS-declared variadic // A derived attribute that returns the size of `idx`-th ODS-declared variadic
@ -270,7 +282,8 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr< class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
"size_t", "size_t",
"auto range = getODSResults(" # idx # ");\n" "auto range = getODSResults(" # idx # ");\n"
"return std::distance(range.begin(), range.end());">; "return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared // A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute // result. If the `idx`-th result is a variadic result, then this attribute
@ -288,7 +301,16 @@ class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
"mlir::ResultElementTypeRange", "mlir::ResultElementTypeRange",
"auto values = getODSResults(" # idx # ");\n" "auto values = getODSResults(" # idx # ");\n"
"return {mlir::ResultElementTypeIterator(values.begin()), " "return {mlir::ResultElementTypeIterator(values.begin()), "
"mlir::ResultElementTypeIterator(values.end())};" "mlir::ResultElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}]
>; >;
// A derived attribute that returns the shapes of the tensors in the actual // A derived attribute that returns the shapes of the tensors in the actual
@ -299,12 +321,15 @@ class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::ResultShapeRange", "mlir::TF::ResultShapeRange",
"auto values = getODSResults(" # idx # ");\n" "auto values = getODSResults(" # idx # ");\n"
"return {mlir::TF::ResultShapeIterator(values.begin()), " "return {mlir::TF::ResultShapeIterator(values.begin()), "
"mlir::TF::ResultShapeIterator(values.end())};" "mlir::TF::ResultShapeIterator(values.end())};",
// TODO(jpienaar): Update post TensorShapeAttr landing.
[{ nullptr }]
>; >;
// A derived attribute that returns the shape of the first result type. // A derived attribute that returns the shape of the first result type.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
"return (*getOperation()->result_type_begin()).cast<ShapedType>();">; "return (*getOperation()->result_type_begin()).cast<ShapedType>();",
[{ TypeAttr::get($_self) }]>;
// A derived attribute that returns the element type of the tensor held by a // A derived attribute that returns the element type of the tensor held by a
// named resource-type operand or result. // named resource-type operand or result.
@ -315,7 +340,6 @@ class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
"assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n" "assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
"return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">; "return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
// A derived attribute that returns the shape of the tensor held by a named // A derived attribute that returns the shape of the tensor held by a named
// resource-type operand or result. // resource-type operand or result.
class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr< class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
@ -324,7 +348,8 @@ class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
" mlir::getElementTypeOrSelf(this->" # name # "())\n" " mlir::getElementTypeOrSelf(this->" # name # "())\n"
" .cast<TF::ResourceType>();\n" " .cast<TF::ResourceType>();\n"
"assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n" "assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
"return resource_type.getSubtypes().begin()->cast<ShapedType>();">; "return resource_type.getSubtypes().begin()->cast<ShapedType>();",
[{ TypeAttr::get($_self) }]>;
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type"; let returnType = "Type";

View File

@ -28,6 +28,7 @@ limitations under the License.
#define TF_OPS #define TF_OPS
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"