Add one_hot op support to TFLite

PiperOrigin-RevId: 206185190
Jared Duke 2018-07-26 10:53:21 -07:00 committed by TensorFlower Gardener
parent 0a3155f7fb
commit 6e658c0a5c
21 changed files with 775 additions and 7 deletions
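In brief: ONE_HOT becomes builtin operator 85, with a reference kernel, schema options, converter support, and tests. The op takes four inputs (indices, depth, on_value, off_value) and produces one output of rank rank(indices) + 1. A quick sketch of the semantics, using the OneHotOpModel test harness added below:

    OneHotOpModel<float> m({2}, /*depth_value=*/3, TensorType_FLOAT32);  // axis defaults to -1
    m.SetIndices({0, 2});
    m.Invoke();
    // m.GetOutputShape() -> {2, 3}
    // m.GetOutput()      -> {1, 0, 0,  0, 0, 1}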

tensorflow/contrib/lite/build_def.bzl

@@ -248,6 +248,7 @@ def generated_test_models():
"mul",
"neg",
"not_equal",
"one_hot",
"pack",
"pad",
"padv2",

tensorflow/contrib/lite/builtin_op_data.h

@@ -282,6 +282,10 @@ typedef struct {
int axis;
} TfLitePackParams;
typedef struct {
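  // Axis along which the one-hot dimension is added; -1 places it innermost.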
int axis;
} TfLiteOneHotParams;
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus

tensorflow/contrib/lite/builtin_ops.h

@@ -110,6 +110,7 @@ typedef enum {
kTfLiteBuiltinReduceMax = 82,
kTfLiteBuiltinPack = 83,
kTfLiteBuiltinLogicalOr = 84,
kTfLiteBuiltinOneHot = 85,
} TfLiteBuiltinOperator;
#ifdef __cplusplus

tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md

@@ -62,6 +62,7 @@ counterparts:
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
*as long as tensors are 2D and axis is the last dimension*
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot)
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
mode and constant_values are not used*
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -

tensorflow/contrib/lite/kernels/BUILD

@@ -176,6 +176,7 @@ cc_library(
"mfcc.cc",
"mul.cc",
"neg.cc",
"one_hot.cc",
"pack.cc",
"pad.cc",
"pooling.cc",
@@ -1171,6 +1172,19 @@ tf_cc_test(
],
)
tf_cc_test(
name = "one_hot_test",
size = "small",
srcs = ["one_hot_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
filegroup(
name = "all_files",
srcs = glob(

tensorflow/contrib/lite/kernels/one_hot.cc

@@ -0,0 +1,199 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace one_hot {
constexpr int kIndicesTensor = 0;
constexpr int kDepthTensor = 1;
constexpr int kOnValueTensor = 2;
constexpr int kOffValueTensor = 3;
constexpr int kOutputTensor = 0;
// Convenience utility for destructuring a node into the appropriate tensors and
// data for the op. Note that this destructuring is quite cheap, so we can avoid
// allocating op-specific, persistent data on the heap.
struct OneHotContext {
OneHotContext(TfLiteContext* context, TfLiteNode* node) {
indices = GetInput(context, node, kIndicesTensor);
depth = GetInput(context, node, kDepthTensor);
on_value = GetInput(context, node, kOnValueTensor);
off_value = GetInput(context, node, kOffValueTensor);
output = GetOutput(context, node, kOutputTensor);
const auto* params =
reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
const int indices_dims = indices->dims->size;
axis = (params->axis == -1) ? indices_dims : params->axis;
output_dims = indices_dims + 1;
dtype = on_value->type;
}
const TfLiteTensor* indices;
const TfLiteTensor* depth;
const TfLiteTensor* on_value;
const TfLiteTensor* off_value;
TfLiteTensor* output;
int axis;
int output_dims;
TfLiteType dtype;
};
template <typename T, typename TI>
void OneHotComputeImpl(const OneHotContext& op_context) {
// prefix_dim_size == # of elements before the axis
// depth == # of elements per axis
// suffix_dim_size == # of elements after the axis
int prefix_dim_size = 1;
for (int i = 0; i < op_context.axis; ++i) {
prefix_dim_size *= op_context.indices->dims->data[i];
}
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
const int depth = *op_context.depth->data.i32;
const T on_value = *GetTensorData<T>(op_context.on_value);
const T off_value = *GetTensorData<T>(op_context.off_value);
// View the indices as a matrix of size:
// prefix_dim_size x suffix_dim_size
// View the output as a matrix of size:
// prefix_dim_size x depth x suffix_dim_size
// Then the output is:
// output(i, j, k) == (indices(i, k) == j) ? on : off
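  // For example, with indices = [1, 0], depth = 3 and axis = -1 (so
  // prefix_dim_size = 2 and suffix_dim_size = 1), the output is the 2 x 3
  // matrix:
  //   [off, on,  off]
  //   [on,  off, off]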
T* output = GetTensorData<T>(op_context.output);
const TI* indices = GetTensorData<TI>(op_context.indices);
for (int i = 0; i < prefix_dim_size; ++i) {
for (int j = 0; j < depth; ++j) {
for (int k = 0; k < suffix_dim_size; ++k, ++output) {
*output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
? on_value
: off_value;
}
}
}
}
template <typename T>
void OneHotCompute(const OneHotContext& op_context) {
if (op_context.indices->type == kTfLiteInt64) {
OneHotComputeImpl<T, int64_t>(op_context);
} else {
OneHotComputeImpl<T, int>(op_context);
}
}
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const OneHotContext& op_context) {
TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
for (int i = 0; i < op_context.output_dims; ++i) {
if (i < op_context.axis) {
output_size->data[i] = op_context.indices->dims->data[i];
} else if (i == op_context.axis) {
output_size->data[i] = *op_context.depth->data.i32;
} else {
output_size->data[i] = op_context.indices->dims->data[i - 1];
}
}
return context->ResizeTensor(context, op_context.output, output_size);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
OneHotContext op_context{context, node};
switch (op_context.dtype) {
// TODO(b/111744875): Support uint8 and quantization.
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
case kTfLiteInt64:
case kTfLiteBool:
op_context.output->type = op_context.dtype;
break;
default:
context->ReportError(context, "Unknown output data type: %d",
op_context.dtype);
return kTfLiteError;
}
TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
op_context.indices->type == kTfLiteInt64);
TF_LITE_ENSURE(context, op_context.axis >= 0 &&
op_context.axis < op_context.output_dims);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
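  // A non-constant depth tensor means the output shape cannot be known here;
  // mark the output dynamic so it is resized during Eval instead.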
if (!IsConstantTensor(op_context.depth)) {
SetTensorToDynamic(op_context.output);
return kTfLiteOk;
}
return ResizeOutputTensor(context, op_context);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
OneHotContext op_context{context, node};
if (IsDynamicTensor(op_context.output)) {
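    // Prepare deferred allocation because depth was not a constant tensor;
    // resize now that its value is available.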
ResizeOutputTensor(context, op_context);
}
switch (op_context.output->type) {
case kTfLiteFloat32:
OneHotCompute<float>(op_context);
break;
case kTfLiteInt32:
OneHotCompute<int>(op_context);
break;
case kTfLiteInt64:
OneHotCompute<int64_t>(op_context);
break;
case kTfLiteBool:
OneHotCompute<bool>(op_context);
break;
default:
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace one_hot
TfLiteRegistration* Register_ONE_HOT() {
static TfLiteRegistration r = {
nullptr,  // init
nullptr,  // free
one_hot::Prepare,
one_hot::Eval,
};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite

tensorflow/contrib/lite/kernels/one_hot_test.cc

@@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <initializer_list>
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
template <typename T>
class OneHotOpModel : public SingleOpModel {
public:
OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
TensorType dtype, int axis = -1, T on_value = 1,
T off_value = 0, TensorType indices_type = TensorType_INT32) {
indices_ = AddInput(indices_type);
int depth = AddInput(TensorType_INT32);
int on = AddInput(dtype);
int off = AddInput(dtype);
output_ = AddOutput(dtype);
SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
CreateOneHotOptions(builder_, axis).Union());
BuildInterpreter({input_shape});
PopulateTensor<int>(depth, {depth_value});
PopulateTensor<T>(on, {on_value});
PopulateTensor<T>(off, {off_value});
}
template <typename TI>
void SetIndices(std::initializer_list<TI> data) {
PopulateTensor<TI>(indices_, data);
}
TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
int32_t GetOutputSize() { return GetTensorSize(output_); }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int indices_;
int output_;
};
TEST(OneHotOpTest, BasicFloat) {
const int depth = 3;
OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
model.SetIndices({0, 1, 2});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
}
TEST(OneHotOpTest, BasicInt) {
const int depth = 3;
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
model.SetIndices({0, 1, 2});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
}
TEST(OneHotOpTest, BasicBool) {
const int depth = 3;
OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
model.SetIndices({0, 1, 2});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({true, false, false, false, true, false, false,
false, true}));
}
TEST(OneHotOpTest, SmallDepth) {
const int depth = 1;
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
model.SetIndices({0, 1, 2});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
}
TEST(OneHotOpTest, BigDepth) {
const int depth = 4;
OneHotOpModel<int> model({2}, depth, TensorType_INT32);
model.SetIndices({0, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
}
TEST(OneHotOpTest, OnOffValues) {
const int depth = 3;
const int axis = -1;
const int on = 5;
const int off = 0;
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
model.SetIndices({0, 2, -1, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
}
TEST(OneHotOpTest, ZeroAxis) {
const int depth = 3;
const int axis = 0;
const int on = 5;
const int off = 0;
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
model.SetIndices({0, 2, -1, 1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
}
TEST(OneHotOpTest, MultiDimensionalIndices) {
const int depth = 3;
const int axis = -1;
const float on = 2;
const float off = 0;
OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
model.SetIndices({0, 2, 1, -1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
EXPECT_THAT(model.GetOutput(),
ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
}
TEST(OneHotOpTest, Int64Indices) {
const int depth = 3;
const int axis = -1;
const int on = 1;
const int off = 0;
OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
TensorType_INT64);
std::initializer_list<int64_t> indices = {0, 1, 2};
model.SetIndices(indices);
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

tensorflow/contrib/lite/kernels/register.cc

@@ -107,6 +107,7 @@ TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_ONE_HOT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -197,6 +198,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_POW, Register_POW());
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
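With the registration wired up, the kernel is reachable through the standard resolver; a minimal sketch, assuming the versioned FindOp API of this release:

    tflite::ops::builtin::BuiltinOpResolver resolver;
    const TfLiteRegistration* reg =
        resolver.FindOp(tflite::BuiltinOperator_ONE_HOT, /*version=*/1);
    // reg->prepare and reg->invoke now dispatch to one_hot::Prepare and
    // one_hot::Eval.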

tensorflow/contrib/lite/model.cc

@@ -730,6 +730,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = static_cast<void*>(params);
break;
}
case BuiltinOperator_ONE_HOT: {
auto* params = MallocPOD<TfLiteOneHotParams>();
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
params->axis = schema_params->axis();
}
*builtin_data = static_cast<void*>(params);
break;
}
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:

tensorflow/contrib/lite/nnapi_delegate.cc

@@ -623,6 +623,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_FAKE_QUANT:
case tflite::BuiltinOperator_PACK:
case tflite::BuiltinOperator_LOGICAL_OR:
case tflite::BuiltinOperator_ONE_HOT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;

tensorflow/contrib/lite/schema/schema.fbs

@@ -166,6 +166,7 @@ enum BuiltinOperator : byte {
REDUCE_MAX = 82,
PACK = 83,
LOGICAL_OR = 84,
ONE_HOT = 85,
}
// Options for the builtin operators.
@@ -230,6 +231,7 @@ union BuiltinOptions {
FakeQuantOptions,
PackOptions,
LogicalOrOptions,
OneHotOptions,
}
enum Padding : byte { SAME, VALID }
@@ -549,6 +551,10 @@ table PackOptions {
table LogicalOrOptions {
}
table OneHotOptions {
axis:int;
}
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
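Converter code fills in the new table through the generated C++ API (see the generated header in the next file); a minimal sketch:

    flatbuffers::FlatBufferBuilder fbb;
    auto options = tflite::CreateOneHotOptions(fbb, /*axis=*/-1);
    // options.Union() is what an Operator stores as builtin_options, with
    // builtin_options_type set to BuiltinOptions_OneHotOptions.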

tensorflow/contrib/lite/schema/schema_generated.h

@@ -211,6 +211,9 @@ struct PackOptionsT;
struct LogicalOrOptions;
struct LogicalOrOptionsT;
struct OneHotOptions;
struct OneHotOptionsT;
struct OperatorCode;
struct OperatorCodeT;
@@ -361,11 +364,12 @@ enum BuiltinOperator {
BuiltinOperator_REDUCE_MAX = 82,
BuiltinOperator_PACK = 83,
BuiltinOperator_LOGICAL_OR = 84,
BuiltinOperator_ONE_HOT = 85,
BuiltinOperator_MIN = BuiltinOperator_ADD,
BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR
BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
};
inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
BuiltinOperator_REDUCE_PROD,
BuiltinOperator_REDUCE_MAX,
BuiltinOperator_PACK,
BuiltinOperator_LOGICAL_OR
BuiltinOperator_LOGICAL_OR,
BuiltinOperator_ONE_HOT
};
return values;
}
@@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() {
"REDUCE_MAX",
"PACK",
"LOGICAL_OR",
"ONE_HOT",
nullptr
};
return names;
@@ -614,11 +620,12 @@ enum BuiltinOptions {
BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_PackOptions = 59,
BuiltinOptions_LogicalOrOptions = 60,
BuiltinOptions_OneHotOptions = 61,
BuiltinOptions_MIN = BuiltinOptions_NONE,
BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions
BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
};
inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
BuiltinOptions_ArgMinOptions,
BuiltinOptions_FakeQuantOptions,
BuiltinOptions_PackOptions,
BuiltinOptions_LogicalOrOptions
BuiltinOptions_LogicalOrOptions,
BuiltinOptions_OneHotOptions
};
return values;
}
@@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() {
"FakeQuantOptions",
"PackOptions",
"LogicalOrOptions",
"OneHotOptions",
nullptr
};
return names;
@@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
};
template<> struct BuiltinOptionsTraits<OneHotOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
};
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LogicalOrOptions ?
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
}
OneHotOptionsT *AsOneHotOptions() {
return type == BuiltinOptions_OneHotOptions ?
reinterpret_cast<OneHotOptionsT *>(value) : nullptr;
}
const OneHotOptionsT *AsOneHotOptions() const {
return type == BuiltinOptions_OneHotOptions ?
reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
}
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OneHotOptionsT : public flatbuffers::NativeTable {
typedef OneHotOptions TableType;
int32_t axis;
OneHotOptionsT()
: axis(0) {
}
};
struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef OneHotOptionsT NativeTableType;
enum {
VT_AXIS = 4
};
int32_t axis() const {
return GetField<int32_t>(VT_AXIS, 0);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_AXIS) &&
verifier.EndTable();
}
OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct OneHotOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_axis(int32_t axis) {
fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0);
}
explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &);
flatbuffers::Offset<OneHotOptions> Finish() {
const auto end = fbb_.EndTable(start_);
auto o = flatbuffers::Offset<OneHotOptions>(end);
return o;
}
};
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
flatbuffers::FlatBufferBuilder &_fbb,
int32_t axis = 0) {
OneHotOptionsBuilder builder_(_fbb);
builder_.add_axis(axis);
return builder_.Finish();
}
flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
}
const OneHotOptions *builtin_options_as_OneHotOptions() const {
return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
}
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr
return builtin_options_as_LogicalOrOptions();
}
template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const {
return builtin_options_as_OneHotOptions();
}
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers:
_fbb);
}
inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OneHotOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = axis(); _o->axis = _e; };
}
inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
return CreateOneHotOptions(_fbb, _o, _rehasher);
}
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _axis = _o->axis;
return tflite::CreateOneHotOptions(
_fbb,
_axis);
}
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_OneHotOptions: {
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
return verifier.VerifyTable(ptr);
}
default: return false;
}
}
@@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_OneHotOptions: {
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
return ptr->UnPack(resolver);
}
default: return nullptr;
}
}
@@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_OneHotOptions: {
auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
}
default: return 0;
}
}
@@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
break;
}
case BuiltinOptions_OneHotOptions: {
value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
break;
}
default:
break;
}
@@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
case BuiltinOptions_OneHotOptions: {
auto ptr = reinterpret_cast<OneHotOptionsT *>(value);
delete ptr;
break;
}
default: break;
}
value = nullptr;

tensorflow/contrib/lite/testing/generate_examples.py

@@ -242,7 +242,9 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
elif dtype in (tf.int32, tf.uint8, tf.int64):
value = np.random.randint(min_value, max_value+1, shape)
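  # With shape=None, the random draw above returns a plain Python scalar,
  # which has no astype(); coerce it via np.dtype(dtype).type instead.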
return value.astype(dtype)
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
dtype)
def create_scalar_data(dtype, min_value=-100, max_value=100):
@@ -1665,6 +1667,65 @@ def make_shape_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_one_hot_tests(zip_path):
"""Make a set of tests to do one_hot."""
test_parameters = [{
"indices_type": [tf.int32, tf.int64],
"indices_shape": [[3], [4, 4], [1, 5], [5, 1]],
"axis": [0, 1],
"dtype": [tf.int32, tf.int64, tf.float32],
"provide_optional_inputs": [True, False],
}]
def build_graph(parameters):
indices = tf.placeholder(
dtype=parameters["indices_type"],
name="indices",
shape=parameters["indices_shape"])
depth = tf.placeholder(dtype=tf.int32, name="depth", shape=())
if not parameters["provide_optional_inputs"]:
out = tf.one_hot(indices=indices, depth=depth)
return [indices, depth], [out]
on_value = tf.placeholder(
dtype=parameters["dtype"], name="on_value", shape=())
off_value = tf.placeholder(
dtype=parameters["dtype"], name="off_value", shape=())
out = tf.one_hot(
indices=indices,
depth=depth,
on_value=on_value,
off_value=off_value,
axis=parameters["axis"],
dtype=parameters["dtype"])
return [indices, depth, on_value, off_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_values = [
create_tensor_data(
parameters["indices_type"],
shape=parameters["indices_shape"],
min_value=-1,
max_value=10),
create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10),
]
if parameters["provide_optional_inputs"]:
input_values.append(
create_tensor_data(
parameters["dtype"], shape=None, min_value=1, max_value=10))
input_values.append(
create_tensor_data(
parameters["dtype"], shape=None, min_value=-1, max_value=0))
return input_values, sess.run(
outputs, feed_dict=dict(zip(inputs, input_values)))
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
def make_resize_bilinear_tests(zip_path):
"""Make a set of tests to do resize_bilinear."""

tensorflow/contrib/lite/toco/export_tensorflow.cc

@@ -1316,6 +1316,20 @@ void ConvertResizeBilinearOperator(const Model& model,
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
}
void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
onehot_op->set_op("OneHot");
onehot_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 4);
for (const auto& input : src_op.inputs) {
*onehot_op->add_input() = input;
}
(*onehot_op->mutable_attr())["T"].set_type(
GetTensorFlowDataType(model, src_op.outputs[0]));
(*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
}
namespace {
// TODO(aselle): Remove when available in absl
absl::string_view FindLongestCommonPrefix(absl::string_view a,
@@ -2158,6 +2172,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertLogicalNotOperator(model,
static_cast<const LogicalNotOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kOneHot) {
ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}

tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc

@@ -201,6 +201,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
case OperatorType::kOneHot: {
CHECK_EQ(op->inputs.size(), 4);
CHECK_EQ(op->outputs.size(), 1);
const ArrayDataType on_value_type =
model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type;
const ArrayDataType off_value_type =
model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT])
.data_type;
CHECK(on_value_type == off_value_type);
model->GetArray(op->outputs[0]).data_type = on_value_type;
break;
}
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);

tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc

@@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) {
}
}
void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
CHECK_EQ(op->inputs.size(), 4);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.has_shape()) {
// Shape already propagated
return;
}
// Yield until indices dims have been resolved.
const auto& indices_array =
model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
if (!indices_array.has_shape()) {
return;
}
// Yield until depth is constant and dims have been resolved.
if (!IsConstantParameterArray(*model,
op->inputs[OneHotOperator::DEPTH_INPUT])) {
return;
}
const auto& depth_array =
model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
if (!depth_array.has_shape()) {
return;
}
CHECK(depth_array.data_type == ArrayDataType::kInt32)
<< "Depth array must be int32.";
CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
<< "Depth array must be scalar.";
const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
CHECK_GE(depth, 0) << "Depth must be non-negative.";
const int indices_dims = indices_array.shape().dimensions_count();
const int output_dims = indices_dims + 1;
const int axis = op->axis == -1 ? indices_dims : op->axis;
CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
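  // For example: indices of shape [4, 3] with depth 5 and axis 1 produce an
  // output of shape [4, 5, 3].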
auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
mutable_dims->resize(output_dims);
for (int i = 0; i < output_dims; ++i) {
int dim = 0;
if (i < axis) {
dim = indices_array.shape().dims(i);
} else if (i == axis) {
dim = depth;
} else {
dim = indices_array.shape().dims(i - 1);
}
(*mutable_dims)[i] = dim;
}
}
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1825,6 +1880,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kAny:
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
break;
case OperatorType::kOneHot:
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);

tensorflow/contrib/lite/toco/import_tensorflow.cc

@@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
return tensorflow::Status::OK();
}
tensorflow::Status ConvertOneHotOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "OneHot");
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
const auto dtype = GetDataTypeAttr(node, "T");
// TODO(b/111744875): Support DT_UINT8 and quantization.
CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
dtype == DT_BOOL);
auto op = absl::make_unique<OneHotOperator>();
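  // tf.one_hot's axis attribute defaults to -1 (innermost); mirror that here
  // when the attribute is absent.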
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
for (const string& input : node.input()) {
op->inputs.push_back(input);
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op.release());
return tensorflow::Status::OK();
}
} // namespace
namespace internal {
@@ -1909,6 +1930,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
{"NoOp", ConvertNoOpOperator},
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
{"OneHot", ConvertOneHotOperator},
{"Pack", ConvertPackOperator},
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},

tensorflow/contrib/lite/toco/model.h

@@ -64,6 +64,7 @@ enum class OperatorType : uint8 {
kMaxPool,
kFakeQuant,
kMul,
kOneHot,
kRandomUniform,
kRange,
kRank,
@@ -1768,6 +1769,27 @@ struct LogicalNotOperator : Operator {
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
};
// OneHot operator:
//
// Inputs:
// Inputs[0]: required: indices.
// Inputs[1]: required: depth.
// Inputs[2]: required: on_value.
// Inputs[3]: required: off_value.
//
// TensorFlow equivalent: OneHot.
struct OneHotOperator : Operator {
enum Inputs {
INDICES_INPUT = 0,
DEPTH_INPUT = 1,
ON_VALUE_INPUT = 2,
OFF_VALUE_INPUT = 3,
};
OneHotOperator() : Operator(OperatorType::kOneHot) {}
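  // Dimension along which the one-hot values are placed; -1 means innermost.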
int axis = -1;
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

tensorflow/contrib/lite/toco/tflite/operator.cc

@@ -1053,6 +1053,23 @@ class Shape
int GetVersion(const Operator& op) const override { return 1; }
};
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
::tflite::BuiltinOptions_OneHotOptions> {
public:
using BuiltinOperator::BuiltinOperator;
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
return ::tflite::CreateOneHotOptions(*builder, op.axis);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
op->axis = options.axis();
}
int GetVersion(const Operator& op) const override { return 1; }
};
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kFakeQuant));
ops.emplace_back(
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
ops.emplace_back(
new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
// Custom Operators.
ops.emplace_back(

tensorflow/contrib/lite/toco/tflite/operator_test.cc

@@ -462,6 +462,14 @@ TEST_F(OperatorTest, BuiltinPack) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
TEST_F(OperatorTest, BuiltinOneHot) {
OneHotOperator op;
op.axis = 2;
auto output_toco_op = SerializeAndDeserialize(
GetOperator("ONE_HOT", OperatorType::kOneHot), op);
EXPECT_EQ(op.axis, output_toco_op->axis);
}
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";

tensorflow/contrib/lite/toco/tooling_util.cc

@@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
HANDLE_OPERATORTYPENAME_CASE(OneHot)
HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)