Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
This commit is contained in:
parent
0a3155f7fb
commit
6e658c0a5c
@ -248,6 +248,7 @@ def generated_test_models():
|
||||
"mul",
|
||||
"neg",
|
||||
"not_equal",
|
||||
"one_hot",
|
||||
"pack",
|
||||
"pad",
|
||||
"padv2",
|
||||
|
@ -282,6 +282,10 @@ typedef struct {
|
||||
int axis;
|
||||
} TfLitePackParams;
|
||||
|
||||
typedef struct {
|
||||
int axis;
|
||||
} TfLiteOneHotParams;
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif // __cplusplus
|
||||
|
@ -110,6 +110,7 @@ typedef enum {
|
||||
kTfLiteBuiltinReduceMax = 82,
|
||||
kTfLiteBuiltinPack = 83,
|
||||
kTfLiteBuiltinLogicalOr = 84,
|
||||
kTfLiteBuiltinOneHot = 85,
|
||||
} TfLiteBuiltinOperator;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -62,6 +62,7 @@ counterparts:
|
||||
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
|
||||
*as long as tensors are 2D and axis is the last dimension*
|
||||
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
|
||||
* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot)
|
||||
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
|
||||
mode and constant_values are not used*
|
||||
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -
|
||||
|
@ -176,6 +176,7 @@ cc_library(
|
||||
"mfcc.cc",
|
||||
"mul.cc",
|
||||
"neg.cc",
|
||||
"one_hot.cc",
|
||||
"pack.cc",
|
||||
"pad.cc",
|
||||
"pooling.cc",
|
||||
@ -1171,6 +1172,19 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "one_hot_test",
|
||||
size = "small",
|
||||
srcs = ["one_hot_test.cc"],
|
||||
tags = ["tflite_not_portable_ios"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/contrib/lite:framework",
|
||||
"//tensorflow/contrib/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
199
tensorflow/contrib/lite/kernels/one_hot.cc
Normal file
199
tensorflow/contrib/lite/kernels/one_hot.cc
Normal file
@ -0,0 +1,199 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/lite/builtin_op_data.h"
|
||||
#include "tensorflow/contrib/lite/context.h"
|
||||
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/contrib/lite/kernels/op_macros.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace one_hot {
|
||||
|
||||
constexpr int kIndicesTensor = 0;
|
||||
constexpr int kDepthTensor = 1;
|
||||
constexpr int kOnValueTensor = 2;
|
||||
constexpr int kOffValueTensor = 3;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
// Convenience utility for destructuring a node into the appropriate tensors and
|
||||
// data for the op. Note that this destructuring is quite cheap, so we can avoid
|
||||
// allocating op-specific, persistent data on the heap.
|
||||
struct OneHotContext {
|
||||
OneHotContext(TfLiteContext* context, TfLiteNode* node) {
|
||||
indices = GetInput(context, node, kIndicesTensor);
|
||||
depth = GetInput(context, node, kDepthTensor);
|
||||
on_value = GetInput(context, node, kOnValueTensor);
|
||||
off_value = GetInput(context, node, kOffValueTensor);
|
||||
output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
const auto* params =
|
||||
reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
|
||||
const int indices_dims = indices->dims->size;
|
||||
axis = (params->axis == -1) ? indices_dims : params->axis;
|
||||
output_dims = indices_dims + 1;
|
||||
dtype = on_value->type;
|
||||
}
|
||||
|
||||
const TfLiteTensor* indices;
|
||||
const TfLiteTensor* depth;
|
||||
const TfLiteTensor* on_value;
|
||||
const TfLiteTensor* off_value;
|
||||
TfLiteTensor* output;
|
||||
int axis;
|
||||
int output_dims;
|
||||
TfLiteType dtype;
|
||||
};
|
||||
|
||||
template <typename T, typename TI>
|
||||
void OneHotComputeImpl(const OneHotContext& op_context) {
|
||||
// prefix_dim_size == # of elements before the axis
|
||||
// depth == # of elements per axis
|
||||
// suffix_dim_size == # of elements after the axis
|
||||
int prefix_dim_size = 1;
|
||||
for (int i = 0; i < op_context.axis; ++i) {
|
||||
prefix_dim_size *= op_context.indices->dims->data[i];
|
||||
}
|
||||
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
|
||||
const int depth = *op_context.depth->data.i32;
|
||||
|
||||
const T on_value = *GetTensorData<T>(op_context.on_value);
|
||||
const T off_value = *GetTensorData<T>(op_context.off_value);
|
||||
|
||||
// View the indices as a matrix of size:
|
||||
// prefix_dim_size x suffix_dim_size
|
||||
// View the output as a matrix of size:
|
||||
// prefix_dim_size x depth x suffix_dim_size
|
||||
// Then the output is:
|
||||
// output(i, j, k) == (indices(i, k) == j) ? on : off
|
||||
T* output = GetTensorData<T>(op_context.output);
|
||||
const TI* indices = GetTensorData<TI>(op_context.indices);
|
||||
for (int i = 0; i < prefix_dim_size; ++i) {
|
||||
for (int j = 0; j < depth; ++j) {
|
||||
for (int k = 0; k < suffix_dim_size; ++k, ++output) {
|
||||
*output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
|
||||
? on_value
|
||||
: off_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void OneHotCompute(const OneHotContext& op_context) {
|
||||
if (op_context.indices->type == kTfLiteInt64) {
|
||||
OneHotComputeImpl<T, int64_t>(op_context);
|
||||
} else {
|
||||
OneHotComputeImpl<T, int>(op_context);
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const OneHotContext& op_context) {
|
||||
TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
|
||||
TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
|
||||
for (int i = 0; i < op_context.output_dims; ++i) {
|
||||
if (i < op_context.axis) {
|
||||
output_size->data[i] = op_context.indices->dims->data[i];
|
||||
} else if (i == op_context.axis) {
|
||||
output_size->data[i] = *op_context.depth->data.i32;
|
||||
} else {
|
||||
output_size->data[i] = op_context.indices->dims->data[i - 1];
|
||||
}
|
||||
}
|
||||
return context->ResizeTensor(context, op_context.output, output_size);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
OneHotContext op_context{context, node};
|
||||
switch (op_context.dtype) {
|
||||
// TODO(b/111744875): Support uint8 and quantization.
|
||||
case kTfLiteFloat32:
|
||||
case kTfLiteInt16:
|
||||
case kTfLiteInt32:
|
||||
case kTfLiteInt64:
|
||||
case kTfLiteBool:
|
||||
op_context.output->type = op_context.dtype;
|
||||
break;
|
||||
default:
|
||||
context->ReportError(context, "Unknown output data type: %d",
|
||||
op_context.dtype);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
|
||||
op_context.indices->type == kTfLiteInt64);
|
||||
TF_LITE_ENSURE(context, op_context.axis >= 0 &&
|
||||
op_context.axis < op_context.output_dims);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
|
||||
TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
|
||||
TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
|
||||
|
||||
if (!IsConstantTensor(op_context.depth)) {
|
||||
SetTensorToDynamic(op_context.output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
return ResizeOutputTensor(context, op_context);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
OneHotContext op_context{context, node};
|
||||
|
||||
if (IsDynamicTensor(op_context.output)) {
|
||||
ResizeOutputTensor(context, op_context);
|
||||
}
|
||||
|
||||
switch (op_context.output->type) {
|
||||
case kTfLiteFloat32:
|
||||
OneHotCompute<float>(op_context);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
OneHotCompute<int>(op_context);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
OneHotCompute<int64_t>(op_context);
|
||||
break;
|
||||
case kTfLiteBool:
|
||||
OneHotCompute<bool>(op_context);
|
||||
break;
|
||||
default:
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace one_hot
|
||||
|
||||
TfLiteRegistration* Register_ONE_HOT() {
|
||||
static TfLiteRegistration r = {
|
||||
nullptr,
|
||||
nullptr,
|
||||
one_hot::Prepare,
|
||||
one_hot::Eval,
|
||||
};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
182
tensorflow/contrib/lite/kernels/one_hot_test.cc
Normal file
182
tensorflow/contrib/lite/kernels/one_hot_test.cc
Normal file
@ -0,0 +1,182 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <initializer_list>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/contrib/lite/interpreter.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#include "tensorflow/contrib/lite/kernels/test_util.h"
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
template <typename T>
|
||||
class OneHotOpModel : public SingleOpModel {
|
||||
public:
|
||||
OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
|
||||
TensorType dtype, int axis = -1, T on_value = 1,
|
||||
T off_value = 0, TensorType indices_type = TensorType_INT32) {
|
||||
indices_ = AddInput(indices_type);
|
||||
int depth = AddInput(TensorType_INT32);
|
||||
int on = AddInput(dtype);
|
||||
int off = AddInput(dtype);
|
||||
output_ = AddOutput(dtype);
|
||||
SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
|
||||
CreateOneHotOptions(builder_, axis).Union());
|
||||
BuildInterpreter({input_shape});
|
||||
|
||||
PopulateTensor<int>(depth, {depth_value});
|
||||
PopulateTensor<T>(on, {on_value});
|
||||
PopulateTensor<T>(off, {off_value});
|
||||
}
|
||||
|
||||
template <typename TI>
|
||||
void SetIndices(std::initializer_list<TI> data) {
|
||||
PopulateTensor<TI>(indices_, data);
|
||||
}
|
||||
|
||||
TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
|
||||
|
||||
int32_t GetOutputSize() { return GetTensorSize(output_); }
|
||||
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
private:
|
||||
int indices_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(OneHotOpTest, BasicFloat) {
|
||||
const int depth = 3;
|
||||
OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
|
||||
model.SetIndices({0, 1, 2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, BasicInt) {
|
||||
const int depth = 3;
|
||||
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
|
||||
model.SetIndices({0, 1, 2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, BasicBool) {
|
||||
const int depth = 3;
|
||||
OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
|
||||
model.SetIndices({0, 1, 2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({true, false, false, false, true, false, false,
|
||||
false, true}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, SmallDepth) {
|
||||
const int depth = 1;
|
||||
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
|
||||
model.SetIndices({0, 1, 2});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, BigDepth) {
|
||||
const int depth = 4;
|
||||
OneHotOpModel<int> model({2}, depth, TensorType_INT32);
|
||||
model.SetIndices({0, 1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, OnOffValues) {
|
||||
const int depth = 3;
|
||||
const int axis = -1;
|
||||
const int on = 5;
|
||||
const int off = 0;
|
||||
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
|
||||
model.SetIndices({0, 2, -1, 1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, ZeroAxis) {
|
||||
const int depth = 3;
|
||||
const int axis = 0;
|
||||
const int on = 5;
|
||||
const int off = 0;
|
||||
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
|
||||
model.SetIndices({0, 2, -1, 1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, MultiDimensionalIndices) {
|
||||
const int depth = 3;
|
||||
const int axis = -1;
|
||||
const float on = 2;
|
||||
const float off = 0;
|
||||
OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
|
||||
model.SetIndices({0, 2, 1, -1});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
|
||||
EXPECT_THAT(model.GetOutput(),
|
||||
ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
|
||||
}
|
||||
|
||||
TEST(OneHotOpTest, Int64Indices) {
|
||||
const int depth = 3;
|
||||
const int axis = -1;
|
||||
const int on = 1;
|
||||
const int off = 0;
|
||||
OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
|
||||
TensorType_INT64);
|
||||
std::initializer_list<int64_t> indices = {0, 1, 2};
|
||||
model.SetIndices(indices);
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -107,6 +107,7 @@ TfLiteRegistration* Register_SHAPE();
|
||||
TfLiteRegistration* Register_POW();
|
||||
TfLiteRegistration* Register_FAKE_QUANT();
|
||||
TfLiteRegistration* Register_PACK();
|
||||
TfLiteRegistration* Register_ONE_HOT();
|
||||
|
||||
BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
|
||||
@ -197,6 +198,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_POW, Register_POW());
|
||||
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
|
||||
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
|
||||
AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
@ -730,6 +730,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = static_cast<void*>(params);
|
||||
break;
|
||||
}
|
||||
case BuiltinOperator_ONE_HOT: {
|
||||
auto* params = MallocPOD<TfLiteOneHotParams>();
|
||||
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
|
||||
params->axis = schema_params->axis();
|
||||
}
|
||||
*builtin_data = static_cast<void*>(params);
|
||||
break;
|
||||
}
|
||||
|
||||
// Below are the ops with no builtin_data strcture.
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
|
@ -623,6 +623,7 @@ TfLiteStatus AddOpsAndParams(
|
||||
case tflite::BuiltinOperator_FAKE_QUANT:
|
||||
case tflite::BuiltinOperator_PACK:
|
||||
case tflite::BuiltinOperator_LOGICAL_OR:
|
||||
case tflite::BuiltinOperator_ONE_HOT:
|
||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||
return kTfLiteError;
|
||||
break;
|
||||
|
@ -166,6 +166,7 @@ enum BuiltinOperator : byte {
|
||||
REDUCE_MAX = 82,
|
||||
PACK = 83,
|
||||
LOGICAL_OR = 84,
|
||||
ONE_HOT = 85,
|
||||
}
|
||||
|
||||
// Options for the builtin operators.
|
||||
@ -230,6 +231,7 @@ union BuiltinOptions {
|
||||
FakeQuantOptions,
|
||||
PackOptions,
|
||||
LogicalOrOptions,
|
||||
OneHotOptions,
|
||||
}
|
||||
|
||||
enum Padding : byte { SAME, VALID }
|
||||
@ -549,6 +551,10 @@ table PackOptions {
|
||||
table LogicalOrOptions {
|
||||
}
|
||||
|
||||
table OneHotOptions {
|
||||
axis:int;
|
||||
}
|
||||
|
||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||
// builtin, or a string if the operator is custom.
|
||||
table OperatorCode {
|
||||
|
@ -211,6 +211,9 @@ struct PackOptionsT;
|
||||
struct LogicalOrOptions;
|
||||
struct LogicalOrOptionsT;
|
||||
|
||||
struct OneHotOptions;
|
||||
struct OneHotOptionsT;
|
||||
|
||||
struct OperatorCode;
|
||||
struct OperatorCodeT;
|
||||
|
||||
@ -361,11 +364,12 @@ enum BuiltinOperator {
|
||||
BuiltinOperator_REDUCE_MAX = 82,
|
||||
BuiltinOperator_PACK = 83,
|
||||
BuiltinOperator_LOGICAL_OR = 84,
|
||||
BuiltinOperator_ONE_HOT = 85,
|
||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||
BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR
|
||||
BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
|
||||
};
|
||||
|
||||
inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
|
||||
inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
|
||||
static BuiltinOperator values[] = {
|
||||
BuiltinOperator_ADD,
|
||||
BuiltinOperator_AVERAGE_POOL_2D,
|
||||
@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
|
||||
BuiltinOperator_REDUCE_PROD,
|
||||
BuiltinOperator_REDUCE_MAX,
|
||||
BuiltinOperator_PACK,
|
||||
BuiltinOperator_LOGICAL_OR
|
||||
BuiltinOperator_LOGICAL_OR,
|
||||
BuiltinOperator_ONE_HOT
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() {
|
||||
"REDUCE_MAX",
|
||||
"PACK",
|
||||
"LOGICAL_OR",
|
||||
"ONE_HOT",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -614,11 +620,12 @@ enum BuiltinOptions {
|
||||
BuiltinOptions_FakeQuantOptions = 58,
|
||||
BuiltinOptions_PackOptions = 59,
|
||||
BuiltinOptions_LogicalOrOptions = 60,
|
||||
BuiltinOptions_OneHotOptions = 61,
|
||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||
BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions
|
||||
BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
|
||||
};
|
||||
|
||||
inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
|
||||
inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
|
||||
static BuiltinOptions values[] = {
|
||||
BuiltinOptions_NONE,
|
||||
BuiltinOptions_Conv2DOptions,
|
||||
@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
|
||||
BuiltinOptions_ArgMinOptions,
|
||||
BuiltinOptions_FakeQuantOptions,
|
||||
BuiltinOptions_PackOptions,
|
||||
BuiltinOptions_LogicalOrOptions
|
||||
BuiltinOptions_LogicalOrOptions,
|
||||
BuiltinOptions_OneHotOptions
|
||||
};
|
||||
return values;
|
||||
}
|
||||
@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() {
|
||||
"FakeQuantOptions",
|
||||
"PackOptions",
|
||||
"LogicalOrOptions",
|
||||
"OneHotOptions",
|
||||
nullptr
|
||||
};
|
||||
return names;
|
||||
@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
|
||||
};
|
||||
|
||||
template<> struct BuiltinOptionsTraits<OneHotOptions> {
|
||||
static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
|
||||
};
|
||||
|
||||
struct BuiltinOptionsUnion {
|
||||
BuiltinOptions type;
|
||||
void *value;
|
||||
@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion {
|
||||
return type == BuiltinOptions_LogicalOrOptions ?
|
||||
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
|
||||
}
|
||||
OneHotOptionsT *AsOneHotOptions() {
|
||||
return type == BuiltinOptions_OneHotOptions ?
|
||||
reinterpret_cast<OneHotOptionsT *>(value) : nullptr;
|
||||
}
|
||||
const OneHotOptionsT *AsOneHotOptions() const {
|
||||
return type == BuiltinOptions_OneHotOptions ?
|
||||
reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||
@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(
|
||||
|
||||
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OneHotOptionsT : public flatbuffers::NativeTable {
|
||||
typedef OneHotOptions TableType;
|
||||
int32_t axis;
|
||||
OneHotOptionsT()
|
||||
: axis(0) {
|
||||
}
|
||||
};
|
||||
|
||||
struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
typedef OneHotOptionsT NativeTableType;
|
||||
enum {
|
||||
VT_AXIS = 4
|
||||
};
|
||||
int32_t axis() const {
|
||||
return GetField<int32_t>(VT_AXIS, 0);
|
||||
}
|
||||
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||
return VerifyTableStart(verifier) &&
|
||||
VerifyField<int32_t>(verifier, VT_AXIS) &&
|
||||
verifier.EndTable();
|
||||
}
|
||||
OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||
static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
};
|
||||
|
||||
struct OneHotOptionsBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
void add_axis(int32_t axis) {
|
||||
fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0);
|
||||
}
|
||||
explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||
: fbb_(_fbb) {
|
||||
start_ = fbb_.StartTable();
|
||||
}
|
||||
OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &);
|
||||
flatbuffers::Offset<OneHotOptions> Finish() {
|
||||
const auto end = fbb_.EndTable(start_);
|
||||
auto o = flatbuffers::Offset<OneHotOptions>(end);
|
||||
return o;
|
||||
}
|
||||
};
|
||||
|
||||
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
|
||||
flatbuffers::FlatBufferBuilder &_fbb,
|
||||
int32_t axis = 0) {
|
||||
OneHotOptionsBuilder builder_(_fbb);
|
||||
builder_.add_axis(axis);
|
||||
return builder_.Finish();
|
||||
}
|
||||
|
||||
flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||
|
||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||
typedef OperatorCode TableType;
|
||||
BuiltinOperator builtin_code;
|
||||
@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const OneHotOptions *builtin_options_as_OneHotOptions() const {
|
||||
return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
|
||||
}
|
||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||
}
|
||||
@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr
|
||||
return builtin_options_as_LogicalOrOptions();
|
||||
}
|
||||
|
||||
template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const {
|
||||
return builtin_options_as_OneHotOptions();
|
||||
}
|
||||
|
||||
struct OperatorBuilder {
|
||||
flatbuffers::FlatBufferBuilder &fbb_;
|
||||
flatbuffers::uoffset_t start_;
|
||||
@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers:
|
||||
_fbb);
|
||||
}
|
||||
|
||||
inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OneHotOptionsT();
|
||||
UnPackTo(_o, _resolver);
|
||||
return _o;
|
||||
}
|
||||
|
||||
inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||
(void)_o;
|
||||
(void)_resolver;
|
||||
{ auto _e = axis(); _o->axis = _e; };
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
return CreateOneHotOptions(_fbb, _o, _rehasher);
|
||||
}
|
||||
|
||||
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||
(void)_rehasher;
|
||||
(void)_o;
|
||||
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||
auto _axis = _o->axis;
|
||||
return tflite::CreateOneHotOptions(
|
||||
_fbb,
|
||||
_axis);
|
||||
}
|
||||
|
||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||
auto _o = new OperatorCodeT();
|
||||
UnPackTo(_o, _resolver);
|
||||
@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
||||
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
case BuiltinOptions_OneHotOptions: {
|
||||
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
|
||||
return verifier.VerifyTable(ptr);
|
||||
}
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
||||
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
case BuiltinOptions_OneHotOptions: {
|
||||
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
|
||||
return ptr->UnPack(resolver);
|
||||
}
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
||||
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
|
||||
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
case BuiltinOptions_OneHotOptions: {
|
||||
auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
|
||||
return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
|
||||
}
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
||||
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_OneHotOptions: {
|
||||
value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
case BuiltinOptions_OneHotOptions: {
|
||||
auto ptr = reinterpret_cast<OneHotOptionsT *>(value);
|
||||
delete ptr;
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
value = nullptr;
|
||||
|
@ -242,7 +242,9 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
|
||||
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
|
||||
elif dtype in (tf.int32, tf.uint8, tf.int64):
|
||||
value = np.random.randint(min_value, max_value+1, shape)
|
||||
return value.astype(dtype)
|
||||
|
||||
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
|
||||
dtype)
|
||||
|
||||
|
||||
def create_scalar_data(dtype, min_value=-100, max_value=100):
|
||||
@ -1665,6 +1667,65 @@ def make_shape_tests(zip_path):
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_one_hot_tests(zip_path):
|
||||
"""Make a set of tests to do one_hot."""
|
||||
|
||||
test_parameters = [{
|
||||
"indices_type": [tf.int32, tf.int64],
|
||||
"indices_shape": [[3], [4, 4], [1, 5], [5, 1]],
|
||||
"axis": [0, 1],
|
||||
"dtype": [tf.int32, tf.int64, tf.float32],
|
||||
"provide_optional_inputs": [True, False],
|
||||
}]
|
||||
|
||||
def build_graph(parameters):
|
||||
indices = tf.placeholder(
|
||||
dtype=parameters["indices_type"],
|
||||
name="indices",
|
||||
shape=parameters["indices_shape"])
|
||||
depth = tf.placeholder(dtype=tf.int32, name="depth", shape=())
|
||||
|
||||
if not parameters["provide_optional_inputs"]:
|
||||
out = tf.one_hot(indices=indices, depth=depth)
|
||||
return [indices, depth], [out]
|
||||
|
||||
on_value = tf.placeholder(
|
||||
dtype=parameters["dtype"], name="on_value", shape=())
|
||||
off_value = tf.placeholder(
|
||||
dtype=parameters["dtype"], name="off_value", shape=())
|
||||
out = tf.one_hot(
|
||||
indices=indices,
|
||||
depth=depth,
|
||||
on_value=on_value,
|
||||
off_value=off_value,
|
||||
axis=parameters["axis"],
|
||||
dtype=parameters["dtype"])
|
||||
return [indices, depth, on_value, off_value], [out]
|
||||
|
||||
def build_inputs(parameters, sess, inputs, outputs):
|
||||
input_values = [
|
||||
create_tensor_data(
|
||||
parameters["indices_type"],
|
||||
shape=parameters["indices_shape"],
|
||||
min_value=-1,
|
||||
max_value=10),
|
||||
create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10),
|
||||
]
|
||||
|
||||
if parameters["provide_optional_inputs"]:
|
||||
input_values.append(
|
||||
create_tensor_data(
|
||||
parameters["dtype"], shape=None, min_value=1, max_value=10))
|
||||
input_values.append(
|
||||
create_tensor_data(
|
||||
parameters["dtype"], shape=None, min_value=-1, max_value=0))
|
||||
|
||||
return input_values, sess.run(
|
||||
outputs, feed_dict=dict(zip(inputs, input_values)))
|
||||
|
||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||
|
||||
|
||||
def make_resize_bilinear_tests(zip_path):
|
||||
"""Make a set of tests to do resize_bilinear."""
|
||||
|
||||
|
@ -1316,6 +1316,20 @@ void ConvertResizeBilinearOperator(const Model& model,
|
||||
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
|
||||
}
|
||||
|
||||
void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
|
||||
GraphDef* tensorflow_graph) {
|
||||
tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
|
||||
onehot_op->set_op("OneHot");
|
||||
onehot_op->set_name(src_op.outputs[0]);
|
||||
CHECK_EQ(src_op.inputs.size(), 4);
|
||||
for (const auto& input : src_op.inputs) {
|
||||
*onehot_op->add_input() = input;
|
||||
}
|
||||
(*onehot_op->mutable_attr())["T"].set_type(
|
||||
GetTensorFlowDataType(model, src_op.outputs[0]));
|
||||
(*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// TODO(aselle): Remove when available in absl
|
||||
absl::string_view FindLongestCommonPrefix(absl::string_view a,
|
||||
@ -2158,6 +2172,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
||||
ConvertLogicalNotOperator(model,
|
||||
static_cast<const LogicalNotOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else if (src_op.type == OperatorType::kOneHot) {
|
||||
ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
|
||||
tensorflow_graph);
|
||||
} else {
|
||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
|
||||
}
|
||||
|
@ -201,6 +201,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
|
||||
SetDataTypeForAllOutputs(model, op, data_type);
|
||||
break;
|
||||
}
|
||||
case OperatorType::kOneHot: {
|
||||
CHECK_EQ(op->inputs.size(), 4);
|
||||
CHECK_EQ(op->outputs.size(), 1);
|
||||
const ArrayDataType on_value_type =
|
||||
model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type;
|
||||
const ArrayDataType off_value_type =
|
||||
model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT])
|
||||
.data_type;
|
||||
CHECK(on_value_type == off_value_type);
|
||||
model->GetArray(op->outputs[0]).data_type = on_value_type;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
// These operators produce outputs with the same type as their 1st input
|
||||
CHECK_GT(op->inputs.size(), 0);
|
||||
|
@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) {
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
|
||||
CHECK_EQ(op->inputs.size(), 4);
|
||||
CHECK_EQ(op->outputs.size(), 1);
|
||||
auto& output_array = model->GetArray(op->outputs[0]);
|
||||
if (output_array.has_shape()) {
|
||||
// Shape already propagated
|
||||
return;
|
||||
}
|
||||
|
||||
// Yield until indices dims have been resolved.
|
||||
const auto& indices_array =
|
||||
model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
|
||||
if (!indices_array.has_shape()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Yield until depth is constant and dims have been resolved.
|
||||
if (!IsConstantParameterArray(*model,
|
||||
op->inputs[OneHotOperator::DEPTH_INPUT])) {
|
||||
return;
|
||||
}
|
||||
const auto& depth_array =
|
||||
model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
|
||||
if (!depth_array.has_shape()) {
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(depth_array.data_type == ArrayDataType::kInt32)
|
||||
<< "Depth array must be int32.";
|
||||
CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
|
||||
<< "Depth array must be scalar.";
|
||||
|
||||
const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
|
||||
CHECK_GE(depth, 0) << "Depth must be non-negative.";
|
||||
|
||||
const int indices_dims = indices_array.shape().dimensions_count();
|
||||
const int output_dims = indices_dims + 1;
|
||||
const int axis = op->axis == -1 ? indices_dims : op->axis;
|
||||
CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
|
||||
|
||||
auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
|
||||
mutable_dims->resize(output_dims);
|
||||
for (int i = 0; i < output_dims; ++i) {
|
||||
int dim = 0;
|
||||
if (i < axis) {
|
||||
dim = indices_array.shape().dims(i);
|
||||
} else if (i == axis) {
|
||||
dim = depth;
|
||||
} else {
|
||||
dim = indices_array.shape().dims(i - 1);
|
||||
}
|
||||
(*mutable_dims)[i] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
|
||||
@ -1825,6 +1880,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
|
||||
case OperatorType::kAny:
|
||||
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
|
||||
break;
|
||||
case OperatorType::kOneHot:
|
||||
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
|
||||
break;
|
||||
default:
|
||||
// Unimplemented, another graph transformation should drop it.
|
||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||
|
@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status ConvertOneHotOperator(
|
||||
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||
Model* model) {
|
||||
CHECK_EQ(node.op(), "OneHot");
|
||||
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
|
||||
|
||||
const auto dtype = GetDataTypeAttr(node, "T");
|
||||
// TODO(b/111744875): Support DT_UINT8 and quantization.
|
||||
CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
|
||||
dtype == DT_BOOL);
|
||||
|
||||
auto op = absl::make_unique<OneHotOperator>();
|
||||
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
|
||||
for (const string& input : node.input()) {
|
||||
op->inputs.push_back(input);
|
||||
}
|
||||
op->outputs.push_back(node.name());
|
||||
model->operators.emplace_back(op.release());
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace internal {
|
||||
@ -1909,6 +1930,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
||||
{"NoOp", ConvertNoOpOperator},
|
||||
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
|
||||
{"OneHot", ConvertOneHotOperator},
|
||||
{"Pack", ConvertPackOperator},
|
||||
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
|
||||
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
|
||||
|
@ -64,6 +64,7 @@ enum class OperatorType : uint8 {
|
||||
kMaxPool,
|
||||
kFakeQuant,
|
||||
kMul,
|
||||
kOneHot,
|
||||
kRandomUniform,
|
||||
kRange,
|
||||
kRank,
|
||||
@ -1768,6 +1769,27 @@ struct LogicalNotOperator : Operator {
|
||||
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
|
||||
};
|
||||
|
||||
// OneHot operator:
|
||||
//
|
||||
// Inputs:
|
||||
// Inputs[0]: required: indices.
|
||||
// Inputs[1]: required: depth.
|
||||
// Inputs[2]: required: on_value.
|
||||
// Inputs[3]: required: off_value.
|
||||
//
|
||||
// TensorFlow equivalent: OneHot.
|
||||
struct OneHotOperator : Operator {
|
||||
enum Inputs {
|
||||
INDICES_INPUT = 0,
|
||||
DEPTH_INPUT = 1,
|
||||
ON_VALUE_INPUT = 2,
|
||||
OFF_VALUE_INPUT = 3,
|
||||
};
|
||||
|
||||
OneHotOperator() : Operator(OperatorType::kOneHot) {}
|
||||
int axis = -1;
|
||||
};
|
||||
|
||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||
|
@ -1053,6 +1053,23 @@ class Shape
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
};
|
||||
|
||||
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
|
||||
::tflite::BuiltinOptions_OneHotOptions> {
|
||||
public:
|
||||
using BuiltinOperator::BuiltinOperator;
|
||||
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||
const TocoOperator& op,
|
||||
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||
return ::tflite::CreateOneHotOptions(*builder, op.axis);
|
||||
}
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {
|
||||
op->axis = options.axis();
|
||||
}
|
||||
|
||||
int GetVersion(const Operator& op) const override { return 1; }
|
||||
};
|
||||
|
||||
class TensorFlowUnsupported : public BaseOperator {
|
||||
public:
|
||||
using BaseOperator::BaseOperator;
|
||||
@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
|
||||
OperatorType::kFakeQuant));
|
||||
ops.emplace_back(
|
||||
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
|
||||
ops.emplace_back(
|
||||
new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
|
||||
|
||||
// Custom Operators.
|
||||
ops.emplace_back(
|
||||
|
@ -462,6 +462,14 @@ TEST_F(OperatorTest, BuiltinPack) {
|
||||
EXPECT_EQ(op.axis, output_toco_op->axis);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinOneHot) {
|
||||
OneHotOperator op;
|
||||
op.axis = 2;
|
||||
auto output_toco_op = SerializeAndDeserialize(
|
||||
GetOperator("ONE_HOT", OperatorType::kOneHot), op);
|
||||
EXPECT_EQ(op.axis, output_toco_op->axis);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, TensorFlowUnsupported) {
|
||||
TensorFlowUnsupportedOperator op;
|
||||
op.tensorflow_op = "MyCustomUnsupportedOp";
|
||||
|
@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
|
||||
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
|
||||
HANDLE_OPERATORTYPENAME_CASE(Neg)
|
||||
HANDLE_OPERATORTYPENAME_CASE(OneHot)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Pack)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Pad)
|
||||
HANDLE_OPERATORTYPENAME_CASE(PadV2)
|
||||
|
Loading…
Reference in New Issue
Block a user