diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index ba2afffb019..eade8803403 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -172,6 +172,7 @@ cc_library( ":tensorflow_lite_ops_inc_gen", ":validators", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/lite/schema:schema_fbs", "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:Dialect", @@ -303,7 +304,7 @@ genrule( srcs = [ "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td", "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//tensorflow/compiler/mlir/lite:ir/tfl_ops.td", + ":ir/tfl_ops.td", ], outs = [ "utils/generated_op_quant_spec_getters.inc", @@ -344,7 +345,7 @@ genrule( srcs = [ "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td", "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//tensorflow/compiler/mlir/lite:ir/tfl_ops.td", + ":ir/tfl_ops.td", ], outs = [ "operator_writers.inc", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 5a480ae8439..b8bbe4135e0 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -35,6 +35,16 @@ static tflite::ActivationFunctionType ConvertTFL_AFAttrForOptionWriter( .Case("SIGN_BIT", tflite::ActivationFunctionType_SIGN_BIT); } +static tflite::TensorType ConvertDerivedTFLiteTypeAttrForOptionWriter( + tflite::TensorType type, flatbuffers::FlatBufferBuilder* builder) { + if (type == tflite::TensorType_INT64) { + return tflite::TensorType_INT64; + } else if (type == tflite::TensorType_INT32) { + return tflite::TensorType_INT32; + } + llvm_unreachable("invalid type in conversion."); +} + static tflite::Padding ConvertTFL_PaddingAttrForOptionWriter( llvm::StringRef str, flatbuffers::FlatBufferBuilder* builder) { return llvm::StringSwitch(str) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 32c90099c3a..5eac0511ab7 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index bbacd16c249..e24531aa805 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -148,6 +148,7 @@ def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; // Derived shape attribute class. //===----------------------------------------------------------------------===// class DerivedShapeAttr : DerivedAttr<"ArrayRef", body>; +class DerivedTFLiteTypeAttr : DerivedAttr<"tflite::TensorType", body>; def TFL_Int32Or64 : IntOfWidths<[32, 64]>; @@ -2059,6 +2060,35 @@ def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ let hasOptions = 1; } +def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> { + let summary = "Unique Op."; + + let description = [{ + This operation returns a tensor `y` containing all of the unique elements of `x` +sorted in the same order that they occur in `x`. This operation also returns a +tensor `idx` the same size as `x` that contains the index of each value of `x` +in the unique output `y`. In other words: + }]; + + let arguments = (ins + // TODO: add uint8 support after quantize support. + TensorOf<[I8, I16, I32, I64, F32]>:$input + ); + + let results = (outs + TensorOf<[I8, I16, I32, I64, F32]>:$output, + TensorOf<[I32, I64]>:$idx + ); + + DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ + return getResult(1)->getType().cast().getElementType(). + cast().getWidth() > 32 ? tflite::TensorType_INT64 : + tflite::TensorType_INT32; + }]>; + + let hasOptions = 1; +} + //===----------------------------------------------------------------------===// // Quantization ops. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 734d5c7a626..e3a25e07d04 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -882,3 +882,21 @@ func @cast(%arg0: tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> { // CHECK-LABEL: cast // CHECK: "tfl.cast"(%arg0) : (tensor<1x2x2x5xi32>) -> tensor<1x2x2x5xf32> } + +func @unique(%arg0: tensor<5xf32>) -> (tensor, tensor) { + %0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor, tensor) + return %0, %1 : tensor , tensor + + // CHECK-LABEL: unique + // CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor, tensor) + // CHECK: %0 +} + +func @unique64(%arg0: tensor<5xf32>) -> (tensor, tensor) { + %0, %1 = "tf.Unique"(%arg0) : (tensor<5xf32>) -> (tensor, tensor) + return %0, %1 : tensor , tensor + + // CHECK-LABEL: unique64 + // CHECK: %0:2 = "tfl.unique"(%arg0) : (tensor<5xf32>) -> (tensor, tensor) + // CHECK: %0 +} diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 138338e6299..d318fa6fa9f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -270,3 +270,5 @@ def : Pat< (TFL_StridedSliceOp $input, $begin, $end, $strides, (convertIntAttrTo32Bit $begin_mask), (convertIntAttrTo32Bit $end_mask), (convertIntAttrTo32Bit $ellipsis_mask), (convertIntAttrTo32Bit $new_axis_mask), (convertIntAttrTo32Bit $shrink_axis_mask))>; + +def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index a3e49b0ad51..61fd578ea11 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -2934,6 +2934,40 @@ Python Semantics. let hasCanonicalizer = 1; } +def TF_UniqueOp : TF_Op<"Unique", [NoSideEffect]> { + let summary = "Finds unique elements in a 1-D tensor."; + + let description = [{ +This operation returns a tensor `y` containing all of the unique elements of `x` +sorted in the same order that they occur in `x`. This operation also returns a +tensor `idx` the same size as `x` that contains the index of each value of `x` +in the unique output `y`. In other words: + +`y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]` + +For example: + +``` +# tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8] +y, idx = unique(x) +y ==> [1, 2, 4, 7, 8] +idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +``` + }]; + + let arguments = (ins + TF_Tensor:$x + ); + + let results = (outs + TF_Tensor:$y, + TF_I32OrI64Tensor:$idx + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>; +} + def TF_UnpackOp : TF_Op<"Unpack", [NoSideEffect]> { let summary = [{ Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. diff --git a/tensorflow/lite/kernels/unique.cc b/tensorflow/lite/kernels/unique.cc index 80c033aa5ce..1054f7f7535 100644 --- a/tensorflow/lite/kernels/unique.cc +++ b/tensorflow/lite/kernels/unique.cc @@ -110,7 +110,7 @@ TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input, default: context->ReportError( context, - "Unique index output array can only be Int32 or In64, requested: ", + "Unique index output array can only be Int32 or In64, requested: %s", TfLiteTypeGetName(params->index_out_type)); } return kTfLiteError;