Add Unique Op to MLIR converter.
PiperOrigin-RevId: 258271679
This commit is contained in:
parent
2010ed0f14
commit
bb797ba263
@ -172,6 +172,7 @@ cc_library(
|
||||
":tensorflow_lite_ops_inc_gen",
|
||||
":validators",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:Dialect",
|
||||
@ -303,7 +304,7 @@ genrule(
|
||||
srcs = [
|
||||
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
|
||||
":ir/tfl_ops.td",
|
||||
],
|
||||
outs = [
|
||||
"utils/generated_op_quant_spec_getters.inc",
|
||||
@ -344,7 +345,7 @@ genrule(
|
||||
srcs = [
|
||||
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
|
||||
":ir/tfl_ops.td",
|
||||
],
|
||||
outs = [
|
||||
"operator_writers.inc",
|
||||
|
@ -35,6 +35,16 @@ static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
|
||||
.Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
|
||||
}
|
||||
|
||||
static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
|
||||
tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
|
||||
if (type == tflite::TensorType_INT64) {
|
||||
return tflite::TensorType_INT64;
|
||||
} else if (type == tflite::TensorType_INT32) {
|
||||
return tflite::TensorType_INT32;
|
||||
}
|
||||
llvm_unreachable("invalid type in conversion.");
|
||||
}
|
||||
|
||||
static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
|
||||
llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
|
||||
return llvm::StringSwitch<tflite::Padding>(str)
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
@ -148,6 +148,7 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
|
||||
// Derived shape attribute class.
|
||||
//===----------------------------------------------------------------------===//
|
||||
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
|
||||
class DerivedTFLiteTypeAttr<code body> : DerivedAttr<"tflite::TensorType", body>;
|
||||
|
||||
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
|
||||
|
||||
@ -2059,6 +2060,35 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> {
|
||||
let summary = "Unique Op.";
|
||||
|
||||
let description = [{
|
||||
This operation returns a tensor `y` containing all of the unique elements of `x`
|
||||
sorted in the same order that they occur in `x`. This operation also returns a
|
||||
tensor `idx` the same size as `x` that contains the index of each value of `x`
|
||||
in the unique output `y`. In other words:
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
// TODO: add uint8 support after quantize support.
|
||||
TensorOf<[I8, I16, I32, I64, F32]>:$input
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[I8, I16, I32, I64, F32]>:$output,
|
||||
TensorOf<[I32, I64]>:$idx
|
||||
);
|
||||
|
||||
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
|
||||
return getResult(1)->getType().cast<TensorType>().getElementType().
|
||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||
tflite::TensorType_INT32;
|
||||
}]>;
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Quantization ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -882,3 +882,21 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
|
||||
// CHECK-LABEL: cast
|
||||
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
|
||||
}
|
||||
|
||||
func @unique(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
|
||||
%0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||
return %0, %1 : tensor<?xf32> , tensor<?xi32>
|
||||
|
||||
// CHECK-LABEL: unique
|
||||
// CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
|
||||
// CHECK: %0
|
||||
}
|
||||
|
||||
func @unique64(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>) {
|
||||
%0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
return %0, %1 : tensor<?xf32> , tensor<?xi64>
|
||||
|
||||
// CHECK-LABEL: unique64
|
||||
// CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>)
|
||||
// CHECK: %0
|
||||
}
|
||||
|
@ -270,3 +270,5 @@ def : Pat<
|
||||
(TFL_StridedSliceOp $input, $begin, $end, $strides,
|
||||
(convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
|
||||
(convertIntAttrTo32Bit $new_axis_mask), (convertIntAttrTo32Bit $shrink_axis_mask))>;
|
||||
|
||||
def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>;
|
||||
|
@ -2934,6 +2934,40 @@ Python Semantics.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> {
|
||||
let summary = "Finds unique elements in a 1-D tensor.";
|
||||
|
||||
let description = [{
|
||||
This operation returns a tensor `y` containing all of the unique elements of `x`
|
||||
sorted in the same order that they occur in `x`. This operation also returns a
|
||||
tensor `idx` the same size as `x` that contains the index of each value of `x`
|
||||
in the unique output `y`. In other words:
|
||||
|
||||
`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
|
||||
y, idx = unique(x)
|
||||
y ==> [1, 2, 4, 7, 8]
|
||||
idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_Tensor:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_Tensor:$y,
|
||||
TF_I32OrI64Tensor:$idx
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_UnpackOp : TF_Op<"Unpack", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
|
||||
|
@ -110,7 +110,7 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Unique index output array can only be Int32 or In64, requested: ",
|
||||
"Unique index output array can only be Int32 or In64, requested: %s",
|
||||
TfLiteTypeGetName(params->index_out_type));
|
||||
}
|
||||
return kTfLiteError;
|
||||
|
Loading…
x
Reference in New Issue
Block a user