Partially add derived attribute materialize functions

Avoids warning when building dialects and in preparation for using these on
export. Deferred shape materialization pending new shape attribute.

PiperOrigin-RevId: 307925585
Change-Id: I3a764915ab514206e5fe8b1c98c9419f8cdd8b2b
This commit is contained in:
Jacques Pienaar 2020-04-22 16:20:26 -07:00 committed by TensorFlower Gardener
parent 49e59a8cad
commit b5c3ed7e9b
3 changed files with 43 additions and 11 deletions

View File

@ -103,8 +103,8 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Derived shape attribute class.
//===----------------------------------------------------------------------===//
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> :
DerivedAttr<"tflite::TensorType", body>;
class DerivedTFLiteTypeAttr<code body, code convert> :
DerivedAttr<"tflite::TensorType", body, convert>;
// TFL Runtime op trait predicate.
class TFL_RuntimePredOpTrait<string desc, Pred pred> :
@ -572,6 +572,8 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
}]>;
}
@ -600,6 +602,8 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
return getResult().getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult().getType().cast<TensorType>().getElementType())
}]>;
}
@ -3111,6 +3115,8 @@ in the unique output `y`. In other words:
return getResult(1).getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}], [{
TypeAttr::get(getResult(1).getType().cast<TensorType>().getElementType())
}]>;
let hasOptions = 1;

View File

@ -233,7 +233,8 @@ def TF_ConvnetDataFormatAttr : StringBasedAttr<
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
"size_t",
"auto range = getODSOperands(" # idx # ");\n"
"return std::distance(range.begin(), range.end());">;
"return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute
@ -251,7 +252,16 @@ class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
"mlir::OperandElementTypeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::OperandElementTypeIterator(values.begin()), "
"mlir::OperandElementTypeIterator(values.end())};"
"mlir::OperandElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}]
>;
// A derived attribute that returns the shapes of the tensors in the actual
@ -262,7 +272,9 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::OperandShapeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::TF::OperandShapeIterator(values.begin()), "
"mlir::TF::OperandShapeIterator(values.end())};"
"mlir::TF::OperandShapeIterator(values.end())};",
// TODO(jpienaar): Update post TensorShapeAttr landing.
[{ nullptr }]
>;
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
@ -270,7 +282,8 @@ class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
"size_t",
"auto range = getODSResults(" # idx # ");\n"
"return std::distance(range.begin(), range.end());">;
"return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute
@ -288,7 +301,16 @@ class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
"mlir::ResultElementTypeRange",
"auto values = getODSResults(" # idx # ");\n"
"return {mlir::ResultElementTypeIterator(values.begin()), "
"mlir::ResultElementTypeIterator(values.end())};"
"mlir::ResultElementTypeIterator(values.end())};",
[{
ArrayAttr::get(
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}(), $_ctx)
}]
>;
// A derived attribute that returns the shapes of the tensors in the actual
@ -299,12 +321,15 @@ class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::ResultShapeRange",
"auto values = getODSResults(" # idx # ");\n"
"return {mlir::TF::ResultShapeIterator(values.begin()), "
"mlir::TF::ResultShapeIterator(values.end())};"
"mlir::TF::ResultShapeIterator(values.end())};",
// TODO(jpienaar): Update post TensorShapeAttr landing.
[{ nullptr }]
>;
// A derived attribute that returns the shape of the first result type.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
"return (*getOperation()->result_type_begin()).cast<ShapedType>();">;
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
[{ TypeAttr::get($_self) }]>;
// A derived attribute that returns the element type of the tensor held by a
// named resource-type operand or result.
@ -315,7 +340,6 @@ class TF_DerivedOperandOrResultHandleTypeAttr<string name> : DerivedTypeAttr<
"assert(!resource_type.getSubtypes().empty() && \"unknown type\");\n"
"return mlir::getElementTypeOrSelf(*resource_type.getSubtypes().begin());">;
// A derived attribute that returns the shape of the tensor held by a named
// resource-type operand or result.
class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
@ -324,7 +348,8 @@ class TF_DerivedOperandOrResultHandleShapeAttr<string name> : DerivedAttr<
" mlir::getElementTypeOrSelf(this->" # name # "())\n"
" .cast<TF::ResourceType>();\n"
"assert(!resource_type.getSubtypes().empty() && \"unknown shape\");\n"
"return resource_type.getSubtypes().begin()->cast<ShapedType>();">;
"return resource_type.getSubtypes().begin()->cast<ShapedType>();",
[{ TypeAttr::get($_self) }]>;
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type";

View File

@ -28,6 +28,7 @@ limitations under the License.
#define TF_OPS
include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td"
include "mlir/Interfaces/CallInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpBase.td"