Add Unique Op to MLIR converter.

PiperOrigin-RevId: 258271679
This commit is contained in:
Karim Nosir 2019-07-15 17:32:34 -07:00 committed by TensorFlower Gardener
parent 2010ed0f14
commit bb797ba263
8 changed files with 99 additions and 3 deletions

View File

@ -172,6 +172,7 @@ cc_library(
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/lite/schema:schema_fbs",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:Dialect",
@ -303,7 +304,7 @@ genrule(
srcs = [
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
":ir/tfl_ops.td",
],
outs = [
"utils/generated_op_quant_spec_getters.inc",
@ -344,7 +345,7 @@ genrule(
srcs = [
"@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td",
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/lite:ir/tfl_ops.td",
":ir/tfl_ops.td",
],
outs = [
"operator_writers.inc",

View File

@ -35,6 +35,16 @@ static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter(
.Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT);
}
static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter(
tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) {
if (type == tflite::TensorType_INT64) {
return tflite::TensorType_INT64;
} else if (type == tflite::TensorType_INT32) {
return tflite::TensorType_INT32;
}
llvm_unreachable("invalid type in conversion.");
}
static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter(
llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) {
return llvm::StringSwitch<tflite::Padding>(str)

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mlir {
namespace TFL {

View File

@ -148,6 +148,7 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Derived shape attribute class.
//===----------------------------------------------------------------------===//
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;
class DerivedTFLiteTypeAttr<code body> : DerivedAttr<"tflite::TensorType", body>;
def TFL_Int32Or64 : IntOfWidths<[32, 64]>;
@ -2059,6 +2060,35 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
let hasOptions = 1;
}
def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> {
let summary = "Unique Op.";
let description = [{
This operation returns a tensor `y` containing all of the unique elements of `x`
sorted in the same order that they occur in `x`. This operation also returns a
tensor `idx` the same size as `x` that contains the index of each value of `x`
in the unique output `y`. In other words:
}];
let arguments = (ins
// TODO: add uint8 support after quantize support.
TensorOf<[I8, I16, I32, I64, F32]>:$input
);
let results = (outs
TensorOf<[I8, I16, I32, I64, F32]>:$output,
TensorOf<[I32, I64]>:$idx
);
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
let hasOptions = 1;
}
//===----------------------------------------------------------------------===//
// Quantization ops.
//===----------------------------------------------------------------------===//

View File

@ -882,3 +882,21 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> {
// CHECK-LABEL: cast
// CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32>
}
func @unique(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0, %1 : tensor<?xf32> , tensor<?xi32>
// CHECK-LABEL: unique
// CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi32>)
// CHECK: %0
}
func @unique64(%arg0: tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>) {
%0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>)
return %0, %1 : tensor<?xf32> , tensor<?xi64>
// CHECK-LABEL: unique64
// CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor<?xf32>, tensor<?xi64>)
// CHECK: %0
}

View File

@ -270,3 +270,5 @@ def : Pat<
(TFL_StridedSliceOp $input, $begin, $end, $strides,
(convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask),
(convertIntAttrTo32Bit $new_axis_mask), (convertIntAttrTo32Bit $shrink_axis_mask))>;
def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>;

View File

@ -2934,6 +2934,40 @@ Python Semantics.
let hasCanonicalizer = 1;
}
def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> {
let summary = "Finds unique elements in a 1-D tensor.";
let description = [{
This operation returns a tensor `y` containing all of the unique elements of `x`
sorted in the same order that they occur in `x`. This operation also returns a
tensor `idx` the same size as `x` that contains the index of each value of `x`
in the unique output `y`. In other words:
`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
For example:
```
# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
y, idx = unique(x)
y ==> [1, 2, 4, 7, 8]
idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
```
}];
let arguments = (ins
TF_Tensor:$x
);
let results = (outs
TF_Tensor:$y,
TF_I32OrI64Tensor:$idx
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>;
}
def TF_UnpackOp : TF_Op<"Unpack", [NoSideEffect]> {
let summary = [{
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.

View File

@ -110,7 +110,7 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
default:
context->ReportError(
context,
"Unique index output array can only be Int32 or In64, requested: ",
"Unique index output array can only be Int32 or In64, requested: %s",
TfLiteTypeGetName(params->index_out_type));
}
return kTfLiteError;