Add Unique Op implementation
PiperOrigin-RevId: 228202673
This commit is contained in:
parent
68340a6445
commit
2c0a8c1647
@ -311,6 +311,7 @@ def generated_test_models():
|
||||
"topk",
|
||||
"transpose",
|
||||
"transpose_conv",
|
||||
"unique",
|
||||
"unpack",
|
||||
"unroll_batch_matmul",
|
||||
"where",
|
||||
|
@ -221,6 +221,7 @@ cc_library(
|
||||
"transpose_conv.cc",
|
||||
"unidirectional_sequence_lstm.cc",
|
||||
"unidirectional_sequence_rnn.cc",
|
||||
"unique.cc",
|
||||
"unpack.cc",
|
||||
"zeros_like.cc",
|
||||
],
|
||||
@ -1233,6 +1234,17 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "unique_test",
|
||||
srcs = ["unique_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
":test_util",
|
||||
"//tensorflow/lite:framework",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -129,6 +129,7 @@ TfLiteRegistration* Register_LEAKY_RELU();
|
||||
TfLiteRegistration* Register_SQUARED_DIFFERENCE();
|
||||
TfLiteRegistration* Register_FILL();
|
||||
TfLiteRegistration* Register_MIRROR_PAD();
|
||||
TfLiteRegistration* Register_UNIQUE();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -284,6 +285,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
|
||||
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
|
||||
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
|
||||
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
164
tensorflow/lite/kernels/unique.cc
Normal file
164
tensorflow/lite/kernels/unique.cc
Normal file
@ -0,0 +1,164 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace unique {
|
||||
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
static const int kOutputUniqueTensor = 0;
|
||||
static const int kOutputIndexTensor = 1;
|
||||
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output_unique_tensor =
|
||||
GetOutput(context, node, kOutputUniqueTensor);
|
||||
TfLiteTensor* output_index_tensor =
|
||||
GetOutput(context, node, kOutputIndexTensor);
|
||||
|
||||
// The op only supports 1D input.
|
||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
|
||||
TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
|
||||
// The unique values are determined during evaluation, so we don't know yet
|
||||
// the size of the output tensor.
|
||||
SetTensorToDynamic(output_unique_tensor);
|
||||
return context->ResizeTensor(context, output_index_tensor,
|
||||
output_index_shape);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Actual evaluation for the unique op.
|
||||
template <typename T, typename I>
|
||||
TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
|
||||
TfLiteNode* node) {
|
||||
// Map from value, to index in the unique elements vector.
|
||||
// Note that we prefer to use map than unordered_map as it showed less
|
||||
// increase in the binary size.
|
||||
std::map<T, int> unique_values;
|
||||
TfLiteTensor* output_indexes = GetOutput(context, node, 1);
|
||||
I* indexes = GetTensorData<I>(output_indexes);
|
||||
const T* data = GetTensorData<T>(input);
|
||||
const int num_elements = NumElements(input);
|
||||
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
const auto element_it = unique_values.find(data[i]);
|
||||
if (element_it != unique_values.end()) {
|
||||
indexes[i] = element_it->second;
|
||||
} else {
|
||||
const int unique_index = unique_values.size();
|
||||
unique_values[data[i]] = unique_index;
|
||||
indexes[i] = unique_index;
|
||||
}
|
||||
}
|
||||
// Allocate output tensor.
|
||||
TfLiteTensor* unique_output = GetOutput(context, node, 0);
|
||||
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
|
||||
TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
|
||||
shape->data[0] = unique_values.size();
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
context->ResizeTensor(context, unique_output, shape.release()));
|
||||
// Set the values in the output tensor.
|
||||
T* output_unique_values = GetTensorData<T>(unique_output);
|
||||
for (int i = 0; i < unique_values.size(); ++i) {
|
||||
output_unique_values[i] = data[indexes[i]];
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
|
||||
TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
|
||||
if (params == nullptr) {
|
||||
context->ReportError(context, "Null params passed");
|
||||
return kTfLiteError;
|
||||
}
|
||||
switch (params->index_out_type) {
|
||||
case kTfLiteInt32:
|
||||
return EvalImpl<T, int32_t>(context, input, node);
|
||||
case kTfLiteInt64:
|
||||
return EvalImpl<T, int64_t>(context, input, node);
|
||||
default:
|
||||
context->ReportError(
|
||||
context,
|
||||
"Unique index output array can only be Int32 or In64, requested: ",
|
||||
TfLiteTypeGetName(params->index_out_type));
|
||||
}
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||
TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
|
||||
NumElements(input));
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteInt8:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
|
||||
break;
|
||||
case kTfLiteInt16:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
|
||||
break;
|
||||
case kTfLiteFloat32:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
|
||||
break;
|
||||
case kTfLiteUInt8:
|
||||
TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Currently Unique doesn't support type: %s",
|
||||
TfLiteTypeGetName(input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace unique
|
||||
|
||||
TfLiteRegistration* Register_UNIQUE() {
|
||||
static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
|
||||
unique::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
103
tensorflow/lite/kernels/unique_test.cc
Normal file
103
tensorflow/lite/kernels/unique_test.cc
Normal file
@ -0,0 +1,103 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T, typename I>
|
||||
class UniqueOpModel : public SingleOpModel {
|
||||
public:
|
||||
UniqueOpModel(const TensorData& input, TensorType input_type,
|
||||
TensorType index_out_type) {
|
||||
input_id_ = AddInput(input);
|
||||
output_id_ = AddOutput(input_type);
|
||||
output_index_id_ = AddOutput(index_out_type);
|
||||
SetBuiltinOp(BuiltinOperator_UNIQUE, BuiltinOptions_UniqueOptions,
|
||||
CreateUniqueOptions(builder_, index_out_type).Union());
|
||||
BuildInterpreter({GetShape(input_id_)});
|
||||
}
|
||||
|
||||
int input_tensor_id() { return input_id_; }
|
||||
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
|
||||
std::vector<I> GetIndexesOutput() {
|
||||
return ExtractVector<I>(output_index_id_);
|
||||
}
|
||||
|
||||
protected:
|
||||
int input_id_;
|
||||
int output_id_;
|
||||
int output_index_id_;
|
||||
};
|
||||
|
||||
TEST(UniqueOpModelTest, OneElement) {
|
||||
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {1}},
|
||||
TensorType_FLOAT32, TensorType_INT32);
|
||||
model.PopulateTensor<float>(model.input_tensor_id(), {5});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||
EXPECT_THAT(model.GetIndexesOutput(), ElementsAreArray({0}));
|
||||
}
|
||||
|
||||
TEST(UniqueOpModelTest, MultipleElements_AllUnique) {
|
||||
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {8}},
|
||||
TensorType_FLOAT32, TensorType_INT32);
|
||||
model.PopulateTensor<float>(model.input_tensor_id(),
|
||||
{5, 2, 3, 51, 6, 72, 7, 8});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 3, 51, 6, 72, 7, 8}));
|
||||
EXPECT_THAT(model.GetIndexesOutput(),
|
||||
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7}));
|
||||
}
|
||||
|
||||
TEST(UniqueOpModelTest, MultipleElements_AllDuplicates) {
|
||||
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {7}},
|
||||
TensorType_FLOAT32, TensorType_INT32);
|
||||
model.PopulateTensor<float>(model.input_tensor_id(), {5, 5, 5, 5, 5, 5, 5});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||
EXPECT_THAT(model.GetIndexesOutput(),
|
||||
ElementsAreArray({0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(UniqueOpModelTest, MultipleElements_SomeDuplicates) {
|
||||
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {7}},
|
||||
TensorType_FLOAT32, TensorType_INT32);
|
||||
model.PopulateTensor<float>(model.input_tensor_id(), {2, 3, 5, 7, 2, 7, 3});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 5, 7}));
|
||||
EXPECT_THAT(model.GetIndexesOutput(),
|
||||
ElementsAreArray({0, 1, 2, 3, 0, 3, 1}));
|
||||
}
|
||||
|
||||
TEST(UniqueOpModelTest, MultipleElements_SomeDuplicates_IndexInt64) {
|
||||
UniqueOpModel<float, int64_t> model({TensorType_FLOAT32, {7}},
|
||||
TensorType_FLOAT32, TensorType_INT64);
|
||||
model.PopulateTensor<float>(model.input_tensor_id(), {2, 3, 5, 7, 2, 7, 3});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 5, 7}));
|
||||
EXPECT_THAT(model.GetIndexesOutput(),
|
||||
ElementsAreArray({0, 1, 2, 3, 0, 3, 1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
@ -3749,6 +3749,55 @@ def make_placeholder_with_default_tests(zip_path):
|
||||
expected_tf_success=3)
|
||||
|
||||
|
||||
def make_unique_tests(zip_path):
|
||||
"""Make a set of tests for Unique op."""
|
||||
|
||||
test_parameters = [
|
||||
{
|
||||
"input_shape": [[1]],
|
||||
"index_type": [tf.int32, tf.int64, None],
|
||||
"input_values": [3]
|
||||
},
|
||||
{
|
||||
"input_shape": [[5]],
|
||||
"index_type": [tf.int32, tf.int64],
|
||||
"input_values": [[3, 2, 1, 2, 3]]
|
||||
},
|
||||
{
|
||||
"input_shape": [[7]],
|
||||
"index_type": [tf.int32, tf.int64],
|
||||
"input_values": [[1, 1, 1, 1, 1, 1, 1]]
|
||||
},
|
||||
{
|
||||
"input_shape": [[5]],
|
||||
"index_type": [tf.int32, tf.int64],
|
||||
"input_values": [[3, 2, 1, 0, -1]]
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
"""Build the graph for the test case."""
|
||||
|
||||
input_tensor = tf.placeholder(
|
||||
dtype=tf.int32, name="input", shape=parameters["input_shape"])
|
||||
if parameters["index_type"] is None:
|
||||
output = tf.unique(input_tensor)
|
||||
else:
|
||||
output = tf.unique(input_tensor, parameters["index_type"])
|
||||
|
||||
return [input_tensor], output
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_values = [create_tensor_data(tf.int32, parameters["input_shape"])]
|
||||
return input_values, sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, input_values)))
|
||||
|
||||
make_zip_of_tests(
|
||||
zip_path,
|
||||
test_parameters,
|
||||
build_graph,
|
||||
build_inputs,
|
||||
expected_tf_success=9)
|
||||
|
||||
# Toco binary path provided by the generate rule.
|
||||
bin_path = None
|
||||
|
||||
|
@ -252,6 +252,14 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
||||
SetDataTypeForAllOutputs(model, op, data_type);
|
||||
break;
|
||||
}
|
||||
case OperatorType::kUnique: {
|
||||
CHECK_EQ(op->outputs.size(), 2);
|
||||
const UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
|
||||
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
|
||||
model->GetArray(op->outputs[0]).data_type = data_type;
|
||||
model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// These operators produce outputs with the same type as their 1st input
|
||||
CHECK_GT(op->inputs.size(), 0);
|
||||
|
@ -1828,6 +1828,20 @@ void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
|
||||
output_array.copy_shape(output_shape);
|
||||
}
|
||||
|
||||
void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||
const auto& input_array = model->GetArray(op->inputs[0]);
|
||||
// We have 2 outputs, the shape of the index tensor, is the same size
|
||||
// as the input array. The unique values tensor, is unknown until runtime.
|
||||
CHECK_EQ(op->outputs.size(), 2);
|
||||
auto& idx_output_array = model->GetArray(op->outputs[1]);
|
||||
|
||||
// Yield until input dims have been resolved, or output already computed
|
||||
if (!input_array.has_shape() || idx_output_array.has_shape()) {
|
||||
return;
|
||||
}
|
||||
idx_output_array.copy_shape(input_array.shape());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
||||
@ -2103,6 +2117,9 @@ void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
|
||||
case OperatorType::kMirrorPad:
|
||||
ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
|
||||
break;
|
||||
case OperatorType::kUnique:
|
||||
ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
|
||||
break;
|
||||
default:
|
||||
// Unimplemented, another graph transformation should drop it.
|
||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||
|
@ -1190,7 +1190,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk };
|
||||
// taken from the given NodeDef, and its number must match NumInputs, unless
|
||||
// kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
|
||||
// will be eligible for being exported as a flex op.
|
||||
template <typename Op, int NumInputs, FlexSupport flex>
|
||||
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
|
||||
tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
@ -1203,6 +1203,11 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||
op->inputs.push_back(node.input(i));
|
||||
}
|
||||
op->outputs.push_back(node.name());
|
||||
if (NumOutputs > 1) {
|
||||
for (int i = 1; i < NumOutputs; ++i) {
|
||||
op->outputs.push_back(node.name() + ":" + std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (flex == kFlexOk) {
|
||||
RetainTensorFlowNodeDef(node, op);
|
||||
@ -1213,20 +1218,20 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||
}
|
||||
|
||||
// Convert a simple operator which is not valid as a flex op.
|
||||
template <typename Op, int NumInputs = kAnyNumInputs>
|
||||
template <typename Op, int NumInputs, int NumOutputs>
|
||||
tensorflow::Status ConvertSimpleOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, kFlexNotOk>(
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
|
||||
node, tf_import_flags, model);
|
||||
}
|
||||
|
||||
// Convert a simple operator which is valid as a flex op.
|
||||
template <typename Op, int NumInputs = kAnyNumInputs>
|
||||
template <typename Op, int NumInputs, int NumOutputs>
|
||||
tensorflow::Status ConvertSimpleOperatorFlexOk(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, kFlexOk>(
|
||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
|
||||
node, tf_import_flags, model);
|
||||
}
|
||||
|
||||
@ -2333,14 +2338,15 @@ ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
|
||||
|
||||
ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
return std::unordered_map<std::string, ConverterType>({
|
||||
{"Abs", ConvertSimpleOperator<AbsOperator>},
|
||||
{"Add", ConvertSimpleOperator<AddOperator, 2>},
|
||||
{"AddN", ConvertSimpleOperatorFlexOk<AddNOperator>},
|
||||
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
|
||||
{"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
|
||||
{"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
|
||||
{"AddN", ConvertSimpleOperatorFlexOk<AddNOperator, kAnyNumInputs, 1>},
|
||||
{"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
|
||||
{"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
|
||||
{"ArgMax", ConvertArgMaxOperator},
|
||||
{"ArgMin", ConvertArgMinOperator},
|
||||
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
|
||||
{"Assert",
|
||||
ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
|
||||
{"AvgPool", ConvertAvgPoolOperator},
|
||||
{"BatchMatMul", ConvertBatchMatMulOperator},
|
||||
{"BatchNormWithGlobalNormalization",
|
||||
@ -2357,98 +2363,99 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
|
||||
{"DepthToSpace", ConvertDepthToSpaceOperator},
|
||||
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
|
||||
{"Div", ConvertSimpleOperator<DivOperator, 2>},
|
||||
{"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
|
||||
{"DynamicPartition", ConvertDynamicPartitionOperator},
|
||||
{"DynamicStitch", ConvertDynamicStitchOperator},
|
||||
{"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2>},
|
||||
{"Exp", ConvertSimpleOperator<ExpOperator, 1>},
|
||||
{"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2>},
|
||||
{"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
|
||||
{"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
|
||||
{"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
|
||||
{"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
|
||||
{"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
|
||||
{"Fill", ConvertSimpleOperator<FillOperator, 2>},
|
||||
{"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
|
||||
{"Floor", ConvertFloorOperator},
|
||||
{"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2>},
|
||||
{"FloorMod", ConvertSimpleOperator<FloorModOperator, 2>},
|
||||
{"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
|
||||
{"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
|
||||
{"FusedBatchNorm", ConvertFusedBatchNormOperator},
|
||||
{"Gather", ConvertGatherOperator},
|
||||
{"GatherV2", ConvertGatherOperator},
|
||||
{"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2>},
|
||||
{"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
|
||||
{"GreaterEqual",
|
||||
ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>},
|
||||
ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
|
||||
{"Identity", ConvertIdentityOperator},
|
||||
{"LRN", ConvertLRNOperator},
|
||||
{"LeakyRelu", ConvertLeakyReluOperator},
|
||||
{"LegacyFedInput", ConvertPlaceholderOperator},
|
||||
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
|
||||
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
|
||||
{"Log", ConvertSimpleOperator<LogOperator, 1>},
|
||||
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
|
||||
{"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2>},
|
||||
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
|
||||
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
|
||||
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
|
||||
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
|
||||
{"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
|
||||
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
|
||||
{"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
|
||||
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
|
||||
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
|
||||
{"MatMul", ConvertMatMulOperator},
|
||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||
{"MaxPool", ConvertMaxPoolOperator},
|
||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
|
||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||
{"Mean", ConvertReduceOperator<MeanOperator>},
|
||||
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
|
||||
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2, 1>},
|
||||
{"Min", ConvertReduceOperator<TensorFlowMinOperator>},
|
||||
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
|
||||
{"Mul", ConvertSimpleOperator<MulOperator, 2>},
|
||||
{"Neg", ConvertSimpleOperator<NegOperator, 1>},
|
||||
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
|
||||
{"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
|
||||
{"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
|
||||
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
||||
{"NoOp", ConvertNoOpOperator},
|
||||
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
|
||||
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
|
||||
{"OneHot", ConvertOneHotOperator},
|
||||
{"Pack", ConvertPackOperator},
|
||||
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
|
||||
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
|
||||
{"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
|
||||
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
|
||||
{"ParallelDynamicStitch", ConvertDynamicStitchOperator},
|
||||
{"Placeholder", ConvertPlaceholderOperator},
|
||||
{"PlaceholderWithDefault", ConvertIdentityOperator},
|
||||
{"Pow", ConvertSimpleOperator<PowOperator, 2>},
|
||||
{"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
|
||||
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
|
||||
{"RandomUniform", ConvertRandomUniform},
|
||||
{"Range", ConvertRangeOperator},
|
||||
{"Rank", ConvertSimpleOperator<RankOperator, 1>},
|
||||
{"RealDiv", ConvertSimpleOperator<DivOperator, 2>},
|
||||
{"Relu", ConvertSimpleOperator<ReluOperator, 1>},
|
||||
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1>},
|
||||
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2>},
|
||||
{"Rank", ConvertSimpleOperator<RankOperator, 1, 1>},
|
||||
{"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
|
||||
{"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
|
||||
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
|
||||
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
|
||||
{"ResizeBilinear", ConvertResizeBilinearOperator},
|
||||
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
|
||||
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>},
|
||||
{"Select", ConvertSimpleOperator<SelectOperator, 3>},
|
||||
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
|
||||
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
|
||||
{"Shape", ConvertShapeOperator},
|
||||
{"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>},
|
||||
{"Sin", ConvertSimpleOperator<SinOperator, 1>},
|
||||
{"Slice", ConvertSimpleOperator<SliceOperator, 3>},
|
||||
{"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
|
||||
{"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
|
||||
{"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
|
||||
{"Softmax", ConvertSoftmaxOperator},
|
||||
{"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
|
||||
{"SpaceToDepth", ConvertSpaceToDepthOperator},
|
||||
{"SparseToDense", ConvertSparseToDenseOperator},
|
||||
{"Split", ConvertSplitOperator},
|
||||
{"SplitV", ConvertSplitVOperator},
|
||||
{"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>},
|
||||
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>},
|
||||
{"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
|
||||
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
|
||||
{"SquaredDifference",
|
||||
ConvertSimpleOperator<SquaredDifferenceOperator, 2>},
|
||||
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
|
||||
{"Squeeze", ConvertSqueezeOperator},
|
||||
{"StopGradient", ConvertIdentityOperator},
|
||||
{"StridedSlice", ConvertStridedSliceOperator},
|
||||
{"Sub", ConvertSimpleOperator<SubOperator, 2>},
|
||||
{"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
|
||||
{"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
|
||||
{"Svdf", ConvertSvdfOperator},
|
||||
{"Switch", ConvertSwitchOperator},
|
||||
{"Tanh", ConvertSimpleOperator<TanhOperator, 1>},
|
||||
{"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2>},
|
||||
{"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
|
||||
{"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
|
||||
{"TopK", ConvertTopKV2Operator},
|
||||
{"TopKV2", ConvertTopKV2Operator},
|
||||
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
|
||||
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
|
||||
{"Unpack", ConvertUnpackOperator},
|
||||
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
|
||||
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
|
||||
{"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
|
||||
{"MirrorPad", ConvertMirrorPadOperator},
|
||||
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -157,7 +157,8 @@ enum class OperatorType : uint8 {
|
||||
kResizeNearestNeighbor,
|
||||
kLeakyRelu,
|
||||
kAbs,
|
||||
kMirrorPad
|
||||
kMirrorPad,
|
||||
kUnique
|
||||
};
|
||||
|
||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||
@ -1953,6 +1954,17 @@ struct MirrorPadOperator : Operator {
|
||||
MirrorPadMode mode;
|
||||
};
|
||||
|
||||
// Unique Operator:
|
||||
//
|
||||
// Inputs:
|
||||
// inputs[0]: required: the input array
|
||||
//
|
||||
// TensorFlow equivalent: Unique
|
||||
struct UniqueOperator : Operator {
|
||||
UniqueOperator() : Operator(OperatorType::kUnique) {}
|
||||
ArrayDataType idx_out_type = ArrayDataType::kInt32;
|
||||
};
|
||||
|
||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||
// graph_transformation module.
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||
@ -1478,6 +1479,31 @@ class MirrorPad
|
||||
: MirrorPadMode::kSymmetric;
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op) const override { return 1; }
|
||||
};
|
||||
|
||||
class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
|
||||
::tflite::BuiltinOptions_UniqueOptions> {
|
||||
public:
|
||||
using BuiltinOperator::BuiltinOperator;
|
||||
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||
const TocoOperator& op,
|
||||
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||
const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
|
||||
return ::tflite::CreateUniqueOptions(
|
||||
*builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
|
||||
? ::tflite::TensorType::TensorType_INT64
|
||||
: ::tflite::TensorType_INT32);
|
||||
}
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {
|
||||
UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
|
||||
unique_op->idx_out_type =
|
||||
options.idx_out_type() == ::tflite::TensorType_INT64
|
||||
? toco::ArrayDataType::kInt64
|
||||
: toco::ArrayDataType::kInt32;
|
||||
}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
@ -1819,6 +1845,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
OperatorType::kSquaredDifference));
|
||||
ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
|
||||
OperatorType::kMirrorPad));
|
||||
ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
|
||||
OperatorType::kUnique));
|
||||
|
||||
// Custom Operators.
|
||||
ops.push_back(
|
||||
|
@ -629,6 +629,15 @@ TEST_F(OperatorTest, BuiltinMirrorPad) {
|
||||
EXPECT_EQ(op.mode, output_toco_op->mode);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinUnique) {
|
||||
UniqueOperator op;
|
||||
op.idx_out_type = ArrayDataType::kInt64;
|
||||
auto output_toco_op =
|
||||
SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op);
|
||||
ASSERT_NE(nullptr, output_toco_op.get());
|
||||
EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -416,6 +416,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
|
||||
HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
|
||||
HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Unique)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
x
Reference in New Issue
Block a user