Add Unique Op implementation
PiperOrigin-RevId: 228202673
This commit is contained in:
parent
68340a6445
commit
2c0a8c1647
@ -311,6 +311,7 @@ def generated_test_models():
|
|||||||
"topk",
|
"topk",
|
||||||
"transpose",
|
"transpose",
|
||||||
"transpose_conv",
|
"transpose_conv",
|
||||||
|
"unique",
|
||||||
"unpack",
|
"unpack",
|
||||||
"unroll_batch_matmul",
|
"unroll_batch_matmul",
|
||||||
"where",
|
"where",
|
||||||
|
@ -221,6 +221,7 @@ cc_library(
|
|||||||
"transpose_conv.cc",
|
"transpose_conv.cc",
|
||||||
"unidirectional_sequence_lstm.cc",
|
"unidirectional_sequence_lstm.cc",
|
||||||
"unidirectional_sequence_rnn.cc",
|
"unidirectional_sequence_rnn.cc",
|
||||||
|
"unique.cc",
|
||||||
"unpack.cc",
|
"unpack.cc",
|
||||||
"zeros_like.cc",
|
"zeros_like.cc",
|
||||||
],
|
],
|
||||||
@ -1233,6 +1234,17 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "unique_test",
|
||||||
|
srcs = ["unique_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":builtin_ops",
|
||||||
|
":test_util",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
@ -129,6 +129,7 @@ TfLiteRegistration* Register_LEAKY_RELU();
|
|||||||
TfLiteRegistration* Register_SQUARED_DIFFERENCE();
|
TfLiteRegistration* Register_SQUARED_DIFFERENCE();
|
||||||
TfLiteRegistration* Register_FILL();
|
TfLiteRegistration* Register_FILL();
|
||||||
TfLiteRegistration* Register_MIRROR_PAD();
|
TfLiteRegistration* Register_MIRROR_PAD();
|
||||||
|
TfLiteRegistration* Register_UNIQUE();
|
||||||
|
|
||||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||||
context->ReportError(
|
context->ReportError(
|
||||||
@ -284,6 +285,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
|
AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
|
||||||
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
|
AddBuiltin(BuiltinOperator_FILL, Register_FILL());
|
||||||
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
|
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
|
||||||
|
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
|
164
tensorflow/lite/kernels/unique.cc
Normal file
164
tensorflow/lite/kernels/unique.cc
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace builtin {
|
||||||
|
namespace unique {
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(TfLiteContext* context, void* buffer) {}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
static const int kOutputUniqueTensor = 0;
|
||||||
|
static const int kOutputIndexTensor = 1;
|
||||||
|
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||||
|
TfLiteTensor* output_unique_tensor =
|
||||||
|
GetOutput(context, node, kOutputUniqueTensor);
|
||||||
|
TfLiteTensor* output_index_tensor =
|
||||||
|
GetOutput(context, node, kOutputIndexTensor);
|
||||||
|
|
||||||
|
// The op only supports 1D input.
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
|
||||||
|
TfLiteIntArray* output_index_shape = TfLiteIntArrayCopy(input->dims);
|
||||||
|
// The unique values are determined during evaluation, so we don't know yet
|
||||||
|
// the size of the output tensor.
|
||||||
|
SetTensorToDynamic(output_unique_tensor);
|
||||||
|
return context->ResizeTensor(context, output_index_tensor,
|
||||||
|
output_index_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Actual evaluation for the unique op.
|
||||||
|
template <typename T, typename I>
|
||||||
|
TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
|
||||||
|
TfLiteNode* node) {
|
||||||
|
// Map from value, to index in the unique elements vector.
|
||||||
|
// Note that we prefer to use map than unordered_map as it showed less
|
||||||
|
// increase in the binary size.
|
||||||
|
std::map<T, int> unique_values;
|
||||||
|
TfLiteTensor* output_indexes = GetOutput(context, node, 1);
|
||||||
|
I* indexes = GetTensorData<I>(output_indexes);
|
||||||
|
const T* data = GetTensorData<T>(input);
|
||||||
|
const int num_elements = NumElements(input);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_elements; ++i) {
|
||||||
|
const auto element_it = unique_values.find(data[i]);
|
||||||
|
if (element_it != unique_values.end()) {
|
||||||
|
indexes[i] = element_it->second;
|
||||||
|
} else {
|
||||||
|
const int unique_index = unique_values.size();
|
||||||
|
unique_values[data[i]] = unique_index;
|
||||||
|
indexes[i] = unique_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Allocate output tensor.
|
||||||
|
TfLiteTensor* unique_output = GetOutput(context, node, 0);
|
||||||
|
std::unique_ptr<TfLiteIntArray, void (*)(TfLiteIntArray*)> shape(
|
||||||
|
TfLiteIntArrayCreate(NumDimensions(input)), TfLiteIntArrayFree);
|
||||||
|
shape->data[0] = unique_values.size();
|
||||||
|
TF_LITE_ENSURE_STATUS(
|
||||||
|
context->ResizeTensor(context, unique_output, shape.release()));
|
||||||
|
// Set the values in the output tensor.
|
||||||
|
T* output_unique_values = GetTensorData<T>(unique_output);
|
||||||
|
for (int i = 0; i < unique_values.size(); ++i) {
|
||||||
|
output_unique_values[i] = data[indexes[i]];
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
TfLiteStatus EvalImpl(TfLiteContext* context, const TfLiteTensor* input,
|
||||||
|
TfLiteNode* node) {
|
||||||
|
auto* params = reinterpret_cast<TfLiteUniqueParams*>(node->builtin_data);
|
||||||
|
if (params == nullptr) {
|
||||||
|
context->ReportError(context, "Null params passed");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
switch (params->index_out_type) {
|
||||||
|
case kTfLiteInt32:
|
||||||
|
return EvalImpl<T, int32_t>(context, input, node);
|
||||||
|
case kTfLiteInt64:
|
||||||
|
return EvalImpl<T, int64_t>(context, input, node);
|
||||||
|
default:
|
||||||
|
context->ReportError(
|
||||||
|
context,
|
||||||
|
"Unique index output array can only be Int32 or In64, requested: ",
|
||||||
|
TfLiteTypeGetName(params->index_out_type));
|
||||||
|
}
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
const TfLiteTensor* input = GetInput(context, node, 0);
|
||||||
|
TfLiteTensor* output_index_tensor = GetOutput(context, node, 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumElements(output_index_tensor),
|
||||||
|
NumElements(input));
|
||||||
|
|
||||||
|
switch (input->type) {
|
||||||
|
case kTfLiteInt8:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<int8_t>(context, input, node));
|
||||||
|
break;
|
||||||
|
case kTfLiteInt16:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<int16_t>(context, input, node));
|
||||||
|
break;
|
||||||
|
case kTfLiteInt32:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<int32_t>(context, input, node));
|
||||||
|
break;
|
||||||
|
case kTfLiteInt64:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<int64_t>(context, input, node));
|
||||||
|
break;
|
||||||
|
case kTfLiteFloat32:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<float>(context, input, node));
|
||||||
|
break;
|
||||||
|
case kTfLiteUInt8:
|
||||||
|
TF_LITE_ENSURE_STATUS(EvalImpl<uint8_t>(context, input, node));
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
context->ReportError(context, "Currently Unique doesn't support type: %s",
|
||||||
|
TfLiteTypeGetName(input->type));
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace unique
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_UNIQUE() {
|
||||||
|
static TfLiteRegistration r = {unique::Init, unique::Free, unique::Prepare,
|
||||||
|
unique::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace builtin
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
103
tensorflow/lite/kernels/unique_test.cc
Normal file
103
tensorflow/lite/kernels/unique_test.cc
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
class UniqueOpModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
UniqueOpModel(const TensorData& input, TensorType input_type,
|
||||||
|
TensorType index_out_type) {
|
||||||
|
input_id_ = AddInput(input);
|
||||||
|
output_id_ = AddOutput(input_type);
|
||||||
|
output_index_id_ = AddOutput(index_out_type);
|
||||||
|
SetBuiltinOp(BuiltinOperator_UNIQUE, BuiltinOptions_UniqueOptions,
|
||||||
|
CreateUniqueOptions(builder_, index_out_type).Union());
|
||||||
|
BuildInterpreter({GetShape(input_id_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
int input_tensor_id() { return input_id_; }
|
||||||
|
|
||||||
|
std::vector<T> GetOutput() { return ExtractVector<T>(output_id_); }
|
||||||
|
std::vector<I> GetIndexesOutput() {
|
||||||
|
return ExtractVector<I>(output_index_id_);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int input_id_;
|
||||||
|
int output_id_;
|
||||||
|
int output_index_id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(UniqueOpModelTest, OneElement) {
|
||||||
|
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {1}},
|
||||||
|
TensorType_FLOAT32, TensorType_INT32);
|
||||||
|
model.PopulateTensor<float>(model.input_tensor_id(), {5});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||||
|
EXPECT_THAT(model.GetIndexesOutput(), ElementsAreArray({0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UniqueOpModelTest, MultipleElements_AllUnique) {
|
||||||
|
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {8}},
|
||||||
|
TensorType_FLOAT32, TensorType_INT32);
|
||||||
|
model.PopulateTensor<float>(model.input_tensor_id(),
|
||||||
|
{5, 2, 3, 51, 6, 72, 7, 8});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5, 2, 3, 51, 6, 72, 7, 8}));
|
||||||
|
EXPECT_THAT(model.GetIndexesOutput(),
|
||||||
|
ElementsAreArray({0, 1, 2, 3, 4, 5, 6, 7}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UniqueOpModelTest, MultipleElements_AllDuplicates) {
|
||||||
|
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {7}},
|
||||||
|
TensorType_FLOAT32, TensorType_INT32);
|
||||||
|
model.PopulateTensor<float>(model.input_tensor_id(), {5, 5, 5, 5, 5, 5, 5});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({5}));
|
||||||
|
EXPECT_THAT(model.GetIndexesOutput(),
|
||||||
|
ElementsAreArray({0, 0, 0, 0, 0, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UniqueOpModelTest, MultipleElements_SomeDuplicates) {
|
||||||
|
UniqueOpModel<float, int32_t> model({TensorType_FLOAT32, {7}},
|
||||||
|
TensorType_FLOAT32, TensorType_INT32);
|
||||||
|
model.PopulateTensor<float>(model.input_tensor_id(), {2, 3, 5, 7, 2, 7, 3});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 5, 7}));
|
||||||
|
EXPECT_THAT(model.GetIndexesOutput(),
|
||||||
|
ElementsAreArray({0, 1, 2, 3, 0, 3, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UniqueOpModelTest, MultipleElements_SomeDuplicates_IndexInt64) {
|
||||||
|
UniqueOpModel<float, int64_t> model({TensorType_FLOAT32, {7}},
|
||||||
|
TensorType_FLOAT32, TensorType_INT64);
|
||||||
|
model.PopulateTensor<float>(model.input_tensor_id(), {2, 3, 5, 7, 2, 7, 3});
|
||||||
|
model.Invoke();
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({2, 3, 5, 7}));
|
||||||
|
EXPECT_THAT(model.GetIndexesOutput(),
|
||||||
|
ElementsAreArray({0, 1, 2, 3, 0, 3, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
@ -3749,6 +3749,55 @@ def make_placeholder_with_default_tests(zip_path):
|
|||||||
expected_tf_success=3)
|
expected_tf_success=3)
|
||||||
|
|
||||||
|
|
||||||
|
def make_unique_tests(zip_path):
|
||||||
|
"""Make a set of tests for Unique op."""
|
||||||
|
|
||||||
|
test_parameters = [
|
||||||
|
{
|
||||||
|
"input_shape": [[1]],
|
||||||
|
"index_type": [tf.int32, tf.int64, None],
|
||||||
|
"input_values": [3]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_shape": [[5]],
|
||||||
|
"index_type": [tf.int32, tf.int64],
|
||||||
|
"input_values": [[3, 2, 1, 2, 3]]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_shape": [[7]],
|
||||||
|
"index_type": [tf.int32, tf.int64],
|
||||||
|
"input_values": [[1, 1, 1, 1, 1, 1, 1]]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"input_shape": [[5]],
|
||||||
|
"index_type": [tf.int32, tf.int64],
|
||||||
|
"input_values": [[3, 2, 1, 0, -1]]
|
||||||
|
}]
|
||||||
|
|
||||||
|
def build_graph(parameters):
|
||||||
|
"""Build the graph for the test case."""
|
||||||
|
|
||||||
|
input_tensor = tf.placeholder(
|
||||||
|
dtype=tf.int32, name="input", shape=parameters["input_shape"])
|
||||||
|
if parameters["index_type"] is None:
|
||||||
|
output = tf.unique(input_tensor)
|
||||||
|
else:
|
||||||
|
output = tf.unique(input_tensor, parameters["index_type"])
|
||||||
|
|
||||||
|
return [input_tensor], output
|
||||||
|
|
||||||
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
|
input_values = [create_tensor_data(tf.int32, parameters["input_shape"])]
|
||||||
|
return input_values, sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, input_values)))
|
||||||
|
|
||||||
|
make_zip_of_tests(
|
||||||
|
zip_path,
|
||||||
|
test_parameters,
|
||||||
|
build_graph,
|
||||||
|
build_inputs,
|
||||||
|
expected_tf_success=9)
|
||||||
|
|
||||||
# Toco binary path provided by the generate rule.
|
# Toco binary path provided by the generate rule.
|
||||||
bin_path = None
|
bin_path = None
|
||||||
|
|
||||||
|
@ -252,6 +252,14 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
|
|||||||
SetDataTypeForAllOutputs(model, op, data_type);
|
SetDataTypeForAllOutputs(model, op, data_type);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case OperatorType::kUnique: {
|
||||||
|
CHECK_EQ(op->outputs.size(), 2);
|
||||||
|
const UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
|
||||||
|
const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
|
||||||
|
model->GetArray(op->outputs[0]).data_type = data_type;
|
||||||
|
model->GetArray(op->outputs[1]).data_type = unique_op->idx_out_type;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// These operators produce outputs with the same type as their 1st input
|
// These operators produce outputs with the same type as their 1st input
|
||||||
CHECK_GT(op->inputs.size(), 0);
|
CHECK_GT(op->inputs.size(), 0);
|
||||||
|
@ -1828,6 +1828,20 @@ void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
|
|||||||
output_array.copy_shape(output_shape);
|
output_array.copy_shape(output_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
|
||||||
|
const auto& input_array = model->GetArray(op->inputs[0]);
|
||||||
|
// We have 2 outputs, the shape of the index tensor, is the same size
|
||||||
|
// as the input array. The unique values tensor, is unknown until runtime.
|
||||||
|
CHECK_EQ(op->outputs.size(), 2);
|
||||||
|
auto& idx_output_array = model->GetArray(op->outputs[1]);
|
||||||
|
|
||||||
|
// Yield until input dims have been resolved, or output already computed
|
||||||
|
if (!input_array.has_shape() || idx_output_array.has_shape()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
idx_output_array.copy_shape(input_array.shape());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
::tensorflow::Status PropagateFixedSizes::Run(Model* model,
|
||||||
@ -2103,6 +2117,9 @@ void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
|
|||||||
case OperatorType::kMirrorPad:
|
case OperatorType::kMirrorPad:
|
||||||
ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
|
ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
|
||||||
break;
|
break;
|
||||||
|
case OperatorType::kUnique:
|
||||||
|
ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
// Unimplemented, another graph transformation should drop it.
|
// Unimplemented, another graph transformation should drop it.
|
||||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||||
|
@ -1190,7 +1190,7 @@ enum FlexSupport { kFlexOk, kFlexNotOk };
|
|||||||
// taken from the given NodeDef, and its number must match NumInputs, unless
|
// taken from the given NodeDef, and its number must match NumInputs, unless
|
||||||
// kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
|
// kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
|
||||||
// will be eligible for being exported as a flex op.
|
// will be eligible for being exported as a flex op.
|
||||||
template <typename Op, int NumInputs, FlexSupport flex>
|
template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
|
||||||
tensorflow::Status ConvertSimpleOperatorGeneric(
|
tensorflow::Status ConvertSimpleOperatorGeneric(
|
||||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
@ -1203,6 +1203,11 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
|
|||||||
op->inputs.push_back(node.input(i));
|
op->inputs.push_back(node.input(i));
|
||||||
}
|
}
|
||||||
op->outputs.push_back(node.name());
|
op->outputs.push_back(node.name());
|
||||||
|
if (NumOutputs > 1) {
|
||||||
|
for (int i = 1; i < NumOutputs; ++i) {
|
||||||
|
op->outputs.push_back(node.name() + ":" + std::to_string(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (flex == kFlexOk) {
|
if (flex == kFlexOk) {
|
||||||
RetainTensorFlowNodeDef(node, op);
|
RetainTensorFlowNodeDef(node, op);
|
||||||
@ -1213,20 +1218,20 @@ tensorflow::Status ConvertSimpleOperatorGeneric(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Convert a simple operator which is not valid as a flex op.
|
// Convert a simple operator which is not valid as a flex op.
|
||||||
template <typename Op, int NumInputs = kAnyNumInputs>
|
template <typename Op, int NumInputs, int NumOutputs>
|
||||||
tensorflow::Status ConvertSimpleOperator(
|
tensorflow::Status ConvertSimpleOperator(
|
||||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, kFlexNotOk>(
|
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
|
||||||
node, tf_import_flags, model);
|
node, tf_import_flags, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert a simple operator which is valid as a flex op.
|
// Convert a simple operator which is valid as a flex op.
|
||||||
template <typename Op, int NumInputs = kAnyNumInputs>
|
template <typename Op, int NumInputs, int NumOutputs>
|
||||||
tensorflow::Status ConvertSimpleOperatorFlexOk(
|
tensorflow::Status ConvertSimpleOperatorFlexOk(
|
||||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
Model* model) {
|
Model* model) {
|
||||||
return ConvertSimpleOperatorGeneric<Op, NumInputs, kFlexOk>(
|
return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
|
||||||
node, tf_import_flags, model);
|
node, tf_import_flags, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2333,14 +2338,15 @@ ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
|
|||||||
|
|
||||||
ConverterMapType GetTensorFlowNodeConverterMap() {
|
ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||||
return std::unordered_map<std::string, ConverterType>({
|
return std::unordered_map<std::string, ConverterType>({
|
||||||
{"Abs", ConvertSimpleOperator<AbsOperator>},
|
{"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
|
||||||
{"Add", ConvertSimpleOperator<AddOperator, 2>},
|
{"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
|
||||||
{"AddN", ConvertSimpleOperatorFlexOk<AddNOperator>},
|
{"AddN", ConvertSimpleOperatorFlexOk<AddNOperator, kAnyNumInputs, 1>},
|
||||||
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
|
{"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
|
||||||
{"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
|
{"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
|
||||||
{"ArgMax", ConvertArgMaxOperator},
|
{"ArgMax", ConvertArgMaxOperator},
|
||||||
{"ArgMin", ConvertArgMinOperator},
|
{"ArgMin", ConvertArgMinOperator},
|
||||||
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
|
{"Assert",
|
||||||
|
ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
|
||||||
{"AvgPool", ConvertAvgPoolOperator},
|
{"AvgPool", ConvertAvgPoolOperator},
|
||||||
{"BatchMatMul", ConvertBatchMatMulOperator},
|
{"BatchMatMul", ConvertBatchMatMulOperator},
|
||||||
{"BatchNormWithGlobalNormalization",
|
{"BatchNormWithGlobalNormalization",
|
||||||
@ -2357,98 +2363,99 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
|
{"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
|
||||||
{"DepthToSpace", ConvertDepthToSpaceOperator},
|
{"DepthToSpace", ConvertDepthToSpaceOperator},
|
||||||
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
|
{"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
|
||||||
{"Div", ConvertSimpleOperator<DivOperator, 2>},
|
{"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
|
||||||
{"DynamicPartition", ConvertDynamicPartitionOperator},
|
{"DynamicPartition", ConvertDynamicPartitionOperator},
|
||||||
{"DynamicStitch", ConvertDynamicStitchOperator},
|
{"DynamicStitch", ConvertDynamicStitchOperator},
|
||||||
{"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2>},
|
{"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
|
||||||
{"Exp", ConvertSimpleOperator<ExpOperator, 1>},
|
{"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
|
||||||
{"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2>},
|
{"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
|
||||||
{"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
|
{"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
|
||||||
{"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
|
{"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
|
||||||
{"Fill", ConvertSimpleOperator<FillOperator, 2>},
|
{"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
|
||||||
{"Floor", ConvertFloorOperator},
|
{"Floor", ConvertFloorOperator},
|
||||||
{"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2>},
|
{"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
|
||||||
{"FloorMod", ConvertSimpleOperator<FloorModOperator, 2>},
|
{"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
|
||||||
{"FusedBatchNorm", ConvertFusedBatchNormOperator},
|
{"FusedBatchNorm", ConvertFusedBatchNormOperator},
|
||||||
{"Gather", ConvertGatherOperator},
|
{"Gather", ConvertGatherOperator},
|
||||||
{"GatherV2", ConvertGatherOperator},
|
{"GatherV2", ConvertGatherOperator},
|
||||||
{"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2>},
|
{"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
|
||||||
{"GreaterEqual",
|
{"GreaterEqual",
|
||||||
ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>},
|
ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
|
||||||
{"Identity", ConvertIdentityOperator},
|
{"Identity", ConvertIdentityOperator},
|
||||||
{"LRN", ConvertLRNOperator},
|
{"LRN", ConvertLRNOperator},
|
||||||
{"LeakyRelu", ConvertLeakyReluOperator},
|
{"LeakyRelu", ConvertLeakyReluOperator},
|
||||||
{"LegacyFedInput", ConvertPlaceholderOperator},
|
{"LegacyFedInput", ConvertPlaceholderOperator},
|
||||||
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
|
{"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
|
||||||
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
|
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
|
||||||
{"Log", ConvertSimpleOperator<LogOperator, 1>},
|
{"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
|
||||||
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
|
{"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
|
||||||
{"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2>},
|
{"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
|
||||||
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
|
{"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
|
||||||
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
|
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
|
||||||
{"MatMul", ConvertMatMulOperator},
|
{"MatMul", ConvertMatMulOperator},
|
||||||
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
|
||||||
{"MaxPool", ConvertMaxPoolOperator},
|
{"MaxPool", ConvertMaxPoolOperator},
|
||||||
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
|
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
|
||||||
{"Mean", ConvertReduceOperator<MeanOperator>},
|
{"Mean", ConvertReduceOperator<MeanOperator>},
|
||||||
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
|
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2, 1>},
|
||||||
{"Min", ConvertReduceOperator<TensorFlowMinOperator>},
|
{"Min", ConvertReduceOperator<TensorFlowMinOperator>},
|
||||||
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
|
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
|
||||||
{"Mul", ConvertSimpleOperator<MulOperator, 2>},
|
{"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
|
||||||
{"Neg", ConvertSimpleOperator<NegOperator, 1>},
|
{"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
|
||||||
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
||||||
{"NoOp", ConvertNoOpOperator},
|
{"NoOp", ConvertNoOpOperator},
|
||||||
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
|
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
|
||||||
{"OneHot", ConvertOneHotOperator},
|
{"OneHot", ConvertOneHotOperator},
|
||||||
{"Pack", ConvertPackOperator},
|
{"Pack", ConvertPackOperator},
|
||||||
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
|
{"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
|
||||||
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
|
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
|
||||||
{"ParallelDynamicStitch", ConvertDynamicStitchOperator},
|
{"ParallelDynamicStitch", ConvertDynamicStitchOperator},
|
||||||
{"Placeholder", ConvertPlaceholderOperator},
|
{"Placeholder", ConvertPlaceholderOperator},
|
||||||
{"PlaceholderWithDefault", ConvertIdentityOperator},
|
{"PlaceholderWithDefault", ConvertIdentityOperator},
|
||||||
{"Pow", ConvertSimpleOperator<PowOperator, 2>},
|
{"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
|
||||||
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
|
{"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
|
||||||
{"RandomUniform", ConvertRandomUniform},
|
{"RandomUniform", ConvertRandomUniform},
|
||||||
{"Range", ConvertRangeOperator},
|
{"Range", ConvertRangeOperator},
|
||||||
{"Rank", ConvertSimpleOperator<RankOperator, 1>},
|
{"Rank", ConvertSimpleOperator<RankOperator, 1, 1>},
|
||||||
{"RealDiv", ConvertSimpleOperator<DivOperator, 2>},
|
{"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
|
||||||
{"Relu", ConvertSimpleOperator<ReluOperator, 1>},
|
{"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
|
||||||
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1>},
|
{"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
|
||||||
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2>},
|
{"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
|
||||||
{"ResizeBilinear", ConvertResizeBilinearOperator},
|
{"ResizeBilinear", ConvertResizeBilinearOperator},
|
||||||
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
|
{"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
|
||||||
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>},
|
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
|
||||||
{"Select", ConvertSimpleOperator<SelectOperator, 3>},
|
{"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
|
||||||
{"Shape", ConvertShapeOperator},
|
{"Shape", ConvertShapeOperator},
|
||||||
{"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>},
|
{"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
|
||||||
{"Sin", ConvertSimpleOperator<SinOperator, 1>},
|
{"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
|
||||||
{"Slice", ConvertSimpleOperator<SliceOperator, 3>},
|
{"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
|
||||||
{"Softmax", ConvertSoftmaxOperator},
|
{"Softmax", ConvertSoftmaxOperator},
|
||||||
{"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
|
{"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
|
||||||
{"SpaceToDepth", ConvertSpaceToDepthOperator},
|
{"SpaceToDepth", ConvertSpaceToDepthOperator},
|
||||||
{"SparseToDense", ConvertSparseToDenseOperator},
|
{"SparseToDense", ConvertSparseToDenseOperator},
|
||||||
{"Split", ConvertSplitOperator},
|
{"Split", ConvertSplitOperator},
|
||||||
{"SplitV", ConvertSplitVOperator},
|
{"SplitV", ConvertSplitVOperator},
|
||||||
{"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>},
|
{"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
|
||||||
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>},
|
{"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
|
||||||
{"SquaredDifference",
|
{"SquaredDifference",
|
||||||
ConvertSimpleOperator<SquaredDifferenceOperator, 2>},
|
ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
|
||||||
{"Squeeze", ConvertSqueezeOperator},
|
{"Squeeze", ConvertSqueezeOperator},
|
||||||
{"StopGradient", ConvertIdentityOperator},
|
{"StopGradient", ConvertIdentityOperator},
|
||||||
{"StridedSlice", ConvertStridedSliceOperator},
|
{"StridedSlice", ConvertStridedSliceOperator},
|
||||||
{"Sub", ConvertSimpleOperator<SubOperator, 2>},
|
{"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
|
||||||
{"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
|
{"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
|
||||||
{"Svdf", ConvertSvdfOperator},
|
{"Svdf", ConvertSvdfOperator},
|
||||||
{"Switch", ConvertSwitchOperator},
|
{"Switch", ConvertSwitchOperator},
|
||||||
{"Tanh", ConvertSimpleOperator<TanhOperator, 1>},
|
{"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
|
||||||
{"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2>},
|
{"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
|
||||||
{"TopK", ConvertTopKV2Operator},
|
{"TopK", ConvertTopKV2Operator},
|
||||||
{"TopKV2", ConvertTopKV2Operator},
|
{"TopKV2", ConvertTopKV2Operator},
|
||||||
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
|
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
|
||||||
{"Unpack", ConvertUnpackOperator},
|
{"Unpack", ConvertUnpackOperator},
|
||||||
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
|
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
|
||||||
{"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
|
{"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
|
||||||
{"MirrorPad", ConvertMirrorPadOperator},
|
{"MirrorPad", ConvertMirrorPadOperator},
|
||||||
|
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,7 +157,8 @@ enum class OperatorType : uint8 {
|
|||||||
kResizeNearestNeighbor,
|
kResizeNearestNeighbor,
|
||||||
kLeakyRelu,
|
kLeakyRelu,
|
||||||
kAbs,
|
kAbs,
|
||||||
kMirrorPad
|
kMirrorPad,
|
||||||
|
kUnique
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||||
@ -1953,6 +1954,17 @@ struct MirrorPadOperator : Operator {
|
|||||||
MirrorPadMode mode;
|
MirrorPadMode mode;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Unique Operator:
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// inputs[0]: required: the input array
|
||||||
|
//
|
||||||
|
// TensorFlow equivalent: Unique
|
||||||
|
struct UniqueOperator : Operator {
|
||||||
|
UniqueOperator() : Operator(OperatorType::kUnique) {}
|
||||||
|
ArrayDataType idx_out_type = ArrayDataType::kInt32;
|
||||||
|
};
|
||||||
|
|
||||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||||
// graph_transformation module.
|
// graph_transformation module.
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||||
#include "tensorflow/lite/toco/model.h"
|
#include "tensorflow/lite/toco/model.h"
|
||||||
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
#include "tensorflow/lite/toco/tflite/builtin_operator.h"
|
||||||
@ -1478,6 +1479,31 @@ class MirrorPad
|
|||||||
: MirrorPadMode::kSymmetric;
|
: MirrorPadMode::kSymmetric;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetVersion(const OperatorSignature& op) const override { return 1; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
|
||||||
|
::tflite::BuiltinOptions_UniqueOptions> {
|
||||||
|
public:
|
||||||
|
using BuiltinOperator::BuiltinOperator;
|
||||||
|
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||||
|
const TocoOperator& op,
|
||||||
|
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||||
|
const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
|
||||||
|
return ::tflite::CreateUniqueOptions(
|
||||||
|
*builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
|
||||||
|
? ::tflite::TensorType::TensorType_INT64
|
||||||
|
: ::tflite::TensorType_INT32);
|
||||||
|
}
|
||||||
|
void ReadOptions(const TfLiteOptions& options,
|
||||||
|
TocoOperator* op) const override {
|
||||||
|
UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
|
||||||
|
unique_op->idx_out_type =
|
||||||
|
options.idx_out_type() == ::tflite::TensorType_INT64
|
||||||
|
? toco::ArrayDataType::kInt64
|
||||||
|
: toco::ArrayDataType::kInt32;
|
||||||
|
}
|
||||||
|
|
||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -1819,6 +1845,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
|||||||
OperatorType::kSquaredDifference));
|
OperatorType::kSquaredDifference));
|
||||||
ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
|
ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
|
||||||
OperatorType::kMirrorPad));
|
OperatorType::kMirrorPad));
|
||||||
|
ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
|
||||||
|
OperatorType::kUnique));
|
||||||
|
|
||||||
// Custom Operators.
|
// Custom Operators.
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
|
@ -629,6 +629,15 @@ TEST_F(OperatorTest, BuiltinMirrorPad) {
|
|||||||
EXPECT_EQ(op.mode, output_toco_op->mode);
|
EXPECT_EQ(op.mode, output_toco_op->mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, BuiltinUnique) {
|
||||||
|
UniqueOperator op;
|
||||||
|
op.idx_out_type = ArrayDataType::kInt64;
|
||||||
|
auto output_toco_op =
|
||||||
|
SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op);
|
||||||
|
ASSERT_NE(nullptr, output_toco_op.get());
|
||||||
|
EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -416,6 +416,7 @@ const char* OperatorTypeName(OperatorType type) {
|
|||||||
HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
|
HANDLE_OPERATORTYPENAME_CASE(LeakyRelu)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
|
HANDLE_OPERATORTYPENAME_CASE(SquaredDifference)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
|
HANDLE_OPERATORTYPENAME_CASE(MirrorPad)
|
||||||
|
HANDLE_OPERATORTYPENAME_CASE(Unique)
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unhandled op type";
|
LOG(FATAL) << "Unhandled op type";
|
||||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||||
|
Loading…
x
Reference in New Issue
Block a user