Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
This commit is contained in:
parent
0a3155f7fb
commit
6e658c0a5c
@ -248,6 +248,7 @@ def generated_test_models():
|
|||||||
"mul",
|
"mul",
|
||||||
"neg",
|
"neg",
|
||||||
"not_equal",
|
"not_equal",
|
||||||
|
"one_hot",
|
||||||
"pack",
|
"pack",
|
||||||
"pad",
|
"pad",
|
||||||
"padv2",
|
"padv2",
|
||||||
|
@ -282,6 +282,10 @@ typedef struct {
|
|||||||
int axis;
|
int axis;
|
||||||
} TfLitePackParams;
|
} TfLitePackParams;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int axis;
|
||||||
|
} TfLiteOneHotParams;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} // extern "C"
|
} // extern "C"
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
|
@ -110,6 +110,7 @@ typedef enum {
|
|||||||
kTfLiteBuiltinReduceMax = 82,
|
kTfLiteBuiltinReduceMax = 82,
|
||||||
kTfLiteBuiltinPack = 83,
|
kTfLiteBuiltinPack = 83,
|
||||||
kTfLiteBuiltinLogicalOr = 84,
|
kTfLiteBuiltinLogicalOr = 84,
|
||||||
|
kTfLiteBuiltinOneHot = 85,
|
||||||
} TfLiteBuiltinOperator;
|
} TfLiteBuiltinOperator;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
@ -62,6 +62,7 @@ counterparts:
|
|||||||
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
|
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
|
||||||
*as long as tensors are 2D and axis is the last dimension*
|
*as long as tensors are 2D and axis is the last dimension*
|
||||||
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
|
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
|
||||||
|
* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot)
|
||||||
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
|
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
|
||||||
mode and constant_values are not used*
|
mode and constant_values are not used*
|
||||||
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -
|
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -
|
||||||
|
@ -176,6 +176,7 @@ cc_library(
|
|||||||
"mfcc.cc",
|
"mfcc.cc",
|
||||||
"mul.cc",
|
"mul.cc",
|
||||||
"neg.cc",
|
"neg.cc",
|
||||||
|
"one_hot.cc",
|
||||||
"pack.cc",
|
"pack.cc",
|
||||||
"pad.cc",
|
"pad.cc",
|
||||||
"pooling.cc",
|
"pooling.cc",
|
||||||
@ -1171,6 +1172,19 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "one_hot_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["one_hot_test.cc"],
|
||||||
|
tags = ["tflite_not_portable_ios"],
|
||||||
|
deps = [
|
||||||
|
":builtin_ops",
|
||||||
|
"//tensorflow/contrib/lite:framework",
|
||||||
|
"//tensorflow/contrib/lite/kernels:test_util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "all_files",
|
name = "all_files",
|
||||||
srcs = glob(
|
srcs = glob(
|
||||||
|
199
tensorflow/contrib/lite/kernels/one_hot.cc
Normal file
199
tensorflow/contrib/lite/kernels/one_hot.cc
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/contrib/lite/builtin_op_data.h"
|
||||||
|
#include "tensorflow/contrib/lite/context.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/op_macros.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace builtin {
|
||||||
|
namespace one_hot {
|
||||||
|
|
||||||
|
constexpr int kIndicesTensor = 0;
|
||||||
|
constexpr int kDepthTensor = 1;
|
||||||
|
constexpr int kOnValueTensor = 2;
|
||||||
|
constexpr int kOffValueTensor = 3;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
// Convenience utility for destructuring a node into the appropriate tensors and
|
||||||
|
// data for the op. Note that this destructuring is quite cheap, so we can avoid
|
||||||
|
// allocating op-specific, persistent data on the heap.
|
||||||
|
struct OneHotContext {
|
||||||
|
OneHotContext(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
indices = GetInput(context, node, kIndicesTensor);
|
||||||
|
depth = GetInput(context, node, kDepthTensor);
|
||||||
|
on_value = GetInput(context, node, kOnValueTensor);
|
||||||
|
off_value = GetInput(context, node, kOffValueTensor);
|
||||||
|
output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
const auto* params =
|
||||||
|
reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
|
||||||
|
const int indices_dims = indices->dims->size;
|
||||||
|
axis = (params->axis == -1) ? indices_dims : params->axis;
|
||||||
|
output_dims = indices_dims + 1;
|
||||||
|
dtype = on_value->type;
|
||||||
|
}
|
||||||
|
|
||||||
|
const TfLiteTensor* indices;
|
||||||
|
const TfLiteTensor* depth;
|
||||||
|
const TfLiteTensor* on_value;
|
||||||
|
const TfLiteTensor* off_value;
|
||||||
|
TfLiteTensor* output;
|
||||||
|
int axis;
|
||||||
|
int output_dims;
|
||||||
|
TfLiteType dtype;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename TI>
|
||||||
|
void OneHotComputeImpl(const OneHotContext& op_context) {
|
||||||
|
// prefix_dim_size == # of elements before the axis
|
||||||
|
// depth == # of elements per axis
|
||||||
|
// suffix_dim_size == # of elements after the axis
|
||||||
|
int prefix_dim_size = 1;
|
||||||
|
for (int i = 0; i < op_context.axis; ++i) {
|
||||||
|
prefix_dim_size *= op_context.indices->dims->data[i];
|
||||||
|
}
|
||||||
|
const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
|
||||||
|
const int depth = *op_context.depth->data.i32;
|
||||||
|
|
||||||
|
const T on_value = *GetTensorData<T>(op_context.on_value);
|
||||||
|
const T off_value = *GetTensorData<T>(op_context.off_value);
|
||||||
|
|
||||||
|
// View the indices as a matrix of size:
|
||||||
|
// prefix_dim_size x suffix_dim_size
|
||||||
|
// View the output as a matrix of size:
|
||||||
|
// prefix_dim_size x depth x suffix_dim_size
|
||||||
|
// Then the output is:
|
||||||
|
// output(i, j, k) == (indices(i, k) == j) ? on : off
|
||||||
|
T* output = GetTensorData<T>(op_context.output);
|
||||||
|
const TI* indices = GetTensorData<TI>(op_context.indices);
|
||||||
|
for (int i = 0; i < prefix_dim_size; ++i) {
|
||||||
|
for (int j = 0; j < depth; ++j) {
|
||||||
|
for (int k = 0; k < suffix_dim_size; ++k, ++output) {
|
||||||
|
*output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
|
||||||
|
? on_value
|
||||||
|
: off_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void OneHotCompute(const OneHotContext& op_context) {
|
||||||
|
if (op_context.indices->type == kTfLiteInt64) {
|
||||||
|
OneHotComputeImpl<T, int64_t>(op_context);
|
||||||
|
} else {
|
||||||
|
OneHotComputeImpl<T, int>(op_context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||||
|
const OneHotContext& op_context) {
|
||||||
|
TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
|
||||||
|
for (int i = 0; i < op_context.output_dims; ++i) {
|
||||||
|
if (i < op_context.axis) {
|
||||||
|
output_size->data[i] = op_context.indices->dims->data[i];
|
||||||
|
} else if (i == op_context.axis) {
|
||||||
|
output_size->data[i] = *op_context.depth->data.i32;
|
||||||
|
} else {
|
||||||
|
output_size->data[i] = op_context.indices->dims->data[i - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return context->ResizeTensor(context, op_context.output, output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
|
||||||
|
OneHotContext op_context{context, node};
|
||||||
|
switch (op_context.dtype) {
|
||||||
|
// TODO(b/111744875): Support uint8 and quantization.
|
||||||
|
case kTfLiteFloat32:
|
||||||
|
case kTfLiteInt16:
|
||||||
|
case kTfLiteInt32:
|
||||||
|
case kTfLiteInt64:
|
||||||
|
case kTfLiteBool:
|
||||||
|
op_context.output->type = op_context.dtype;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
context->ReportError(context, "Unknown output data type: %d",
|
||||||
|
op_context.dtype);
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
|
||||||
|
op_context.indices->type == kTfLiteInt64);
|
||||||
|
TF_LITE_ENSURE(context, op_context.axis >= 0 &&
|
||||||
|
op_context.axis < op_context.output_dims);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
|
||||||
|
TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
|
||||||
|
|
||||||
|
if (!IsConstantTensor(op_context.depth)) {
|
||||||
|
SetTensorToDynamic(op_context.output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ResizeOutputTensor(context, op_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
OneHotContext op_context{context, node};
|
||||||
|
|
||||||
|
if (IsDynamicTensor(op_context.output)) {
|
||||||
|
ResizeOutputTensor(context, op_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (op_context.output->type) {
|
||||||
|
case kTfLiteFloat32:
|
||||||
|
OneHotCompute<float>(op_context);
|
||||||
|
break;
|
||||||
|
case kTfLiteInt32:
|
||||||
|
OneHotCompute<int>(op_context);
|
||||||
|
break;
|
||||||
|
case kTfLiteInt64:
|
||||||
|
OneHotCompute<int64_t>(op_context);
|
||||||
|
break;
|
||||||
|
case kTfLiteBool:
|
||||||
|
OneHotCompute<bool>(op_context);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace one_hot
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_ONE_HOT() {
|
||||||
|
static TfLiteRegistration r = {
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
one_hot::Prepare,
|
||||||
|
one_hot::Eval,
|
||||||
|
};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace builtin
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
182
tensorflow/contrib/lite/kernels/one_hot_test.cc
Normal file
182
tensorflow/contrib/lite/kernels/one_hot_test.cc
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <initializer_list>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/contrib/lite/interpreter.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/contrib/lite/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class OneHotOpModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
|
||||||
|
TensorType dtype, int axis = -1, T on_value = 1,
|
||||||
|
T off_value = 0, TensorType indices_type = TensorType_INT32) {
|
||||||
|
indices_ = AddInput(indices_type);
|
||||||
|
int depth = AddInput(TensorType_INT32);
|
||||||
|
int on = AddInput(dtype);
|
||||||
|
int off = AddInput(dtype);
|
||||||
|
output_ = AddOutput(dtype);
|
||||||
|
SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
|
||||||
|
CreateOneHotOptions(builder_, axis).Union());
|
||||||
|
BuildInterpreter({input_shape});
|
||||||
|
|
||||||
|
PopulateTensor<int>(depth, {depth_value});
|
||||||
|
PopulateTensor<T>(on, {on_value});
|
||||||
|
PopulateTensor<T>(off, {off_value});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TI>
|
||||||
|
void SetIndices(std::initializer_list<TI> data) {
|
||||||
|
PopulateTensor<TI>(indices_, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
|
||||||
|
|
||||||
|
int32_t GetOutputSize() { return GetTensorSize(output_); }
|
||||||
|
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int indices_;
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, BasicFloat) {
|
||||||
|
const int depth = 3;
|
||||||
|
OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
|
||||||
|
model.SetIndices({0, 1, 2});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, BasicInt) {
|
||||||
|
const int depth = 3;
|
||||||
|
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
|
||||||
|
model.SetIndices({0, 1, 2});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, BasicBool) {
|
||||||
|
const int depth = 3;
|
||||||
|
OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
|
||||||
|
model.SetIndices({0, 1, 2});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({true, false, false, false, true, false, false,
|
||||||
|
false, true}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, SmallDepth) {
|
||||||
|
const int depth = 1;
|
||||||
|
OneHotOpModel<int> model({3}, depth, TensorType_INT32);
|
||||||
|
model.SetIndices({0, 1, 2});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, BigDepth) {
|
||||||
|
const int depth = 4;
|
||||||
|
OneHotOpModel<int> model({2}, depth, TensorType_INT32);
|
||||||
|
model.SetIndices({0, 1});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, OnOffValues) {
|
||||||
|
const int depth = 3;
|
||||||
|
const int axis = -1;
|
||||||
|
const int on = 5;
|
||||||
|
const int off = 0;
|
||||||
|
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
|
||||||
|
model.SetIndices({0, 2, -1, 1});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, ZeroAxis) {
|
||||||
|
const int depth = 3;
|
||||||
|
const int axis = 0;
|
||||||
|
const int on = 5;
|
||||||
|
const int off = 0;
|
||||||
|
OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
|
||||||
|
model.SetIndices({0, 2, -1, 1});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, MultiDimensionalIndices) {
|
||||||
|
const int depth = 3;
|
||||||
|
const int axis = -1;
|
||||||
|
const float on = 2;
|
||||||
|
const float off = 0;
|
||||||
|
OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
|
||||||
|
model.SetIndices({0, 2, 1, -1});
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(),
|
||||||
|
ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(OneHotOpTest, Int64Indices) {
|
||||||
|
const int depth = 3;
|
||||||
|
const int axis = -1;
|
||||||
|
const int on = 1;
|
||||||
|
const int off = 0;
|
||||||
|
OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
|
||||||
|
TensorType_INT64);
|
||||||
|
std::initializer_list<int64_t> indices = {0, 1, 2};
|
||||||
|
model.SetIndices(indices);
|
||||||
|
model.Invoke();
|
||||||
|
|
||||||
|
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
|
||||||
|
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -107,6 +107,7 @@ TfLiteRegistration* Register_SHAPE();
|
|||||||
TfLiteRegistration* Register_POW();
|
TfLiteRegistration* Register_POW();
|
||||||
TfLiteRegistration* Register_FAKE_QUANT();
|
TfLiteRegistration* Register_FAKE_QUANT();
|
||||||
TfLiteRegistration* Register_PACK();
|
TfLiteRegistration* Register_PACK();
|
||||||
|
TfLiteRegistration* Register_ONE_HOT();
|
||||||
|
|
||||||
BuiltinOpResolver::BuiltinOpResolver() {
|
BuiltinOpResolver::BuiltinOpResolver() {
|
||||||
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
|
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
|
||||||
@ -197,6 +198,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_POW, Register_POW());
|
AddBuiltin(BuiltinOperator_POW, Register_POW());
|
||||||
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
|
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
|
||||||
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
|
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
|
||||||
|
AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
|
@ -730,6 +730,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = static_cast<void*>(params);
|
*builtin_data = static_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOperator_ONE_HOT: {
|
||||||
|
auto* params = MallocPOD<TfLiteOneHotParams>();
|
||||||
|
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
|
||||||
|
params->axis = schema_params->axis();
|
||||||
|
}
|
||||||
|
*builtin_data = static_cast<void*>(params);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Below are the ops with no builtin_data strcture.
|
// Below are the ops with no builtin_data strcture.
|
||||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||||
|
@ -623,6 +623,7 @@ TfLiteStatus AddOpsAndParams(
|
|||||||
case tflite::BuiltinOperator_FAKE_QUANT:
|
case tflite::BuiltinOperator_FAKE_QUANT:
|
||||||
case tflite::BuiltinOperator_PACK:
|
case tflite::BuiltinOperator_PACK:
|
||||||
case tflite::BuiltinOperator_LOGICAL_OR:
|
case tflite::BuiltinOperator_LOGICAL_OR:
|
||||||
|
case tflite::BuiltinOperator_ONE_HOT:
|
||||||
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
logError("Op code %d is currently not delegated to NNAPI", builtin);
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
break;
|
break;
|
||||||
|
@ -166,6 +166,7 @@ enum BuiltinOperator : byte {
|
|||||||
REDUCE_MAX = 82,
|
REDUCE_MAX = 82,
|
||||||
PACK = 83,
|
PACK = 83,
|
||||||
LOGICAL_OR = 84,
|
LOGICAL_OR = 84,
|
||||||
|
ONE_HOT = 85,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options for the builtin operators.
|
// Options for the builtin operators.
|
||||||
@ -230,6 +231,7 @@ union BuiltinOptions {
|
|||||||
FakeQuantOptions,
|
FakeQuantOptions,
|
||||||
PackOptions,
|
PackOptions,
|
||||||
LogicalOrOptions,
|
LogicalOrOptions,
|
||||||
|
OneHotOptions,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Padding : byte { SAME, VALID }
|
enum Padding : byte { SAME, VALID }
|
||||||
@ -549,6 +551,10 @@ table PackOptions {
|
|||||||
table LogicalOrOptions {
|
table LogicalOrOptions {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
table OneHotOptions {
|
||||||
|
axis:int;
|
||||||
|
}
|
||||||
|
|
||||||
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
|
||||||
// builtin, or a string if the operator is custom.
|
// builtin, or a string if the operator is custom.
|
||||||
table OperatorCode {
|
table OperatorCode {
|
||||||
|
@ -211,6 +211,9 @@ struct PackOptionsT;
|
|||||||
struct LogicalOrOptions;
|
struct LogicalOrOptions;
|
||||||
struct LogicalOrOptionsT;
|
struct LogicalOrOptionsT;
|
||||||
|
|
||||||
|
struct OneHotOptions;
|
||||||
|
struct OneHotOptionsT;
|
||||||
|
|
||||||
struct OperatorCode;
|
struct OperatorCode;
|
||||||
struct OperatorCodeT;
|
struct OperatorCodeT;
|
||||||
|
|
||||||
@ -361,11 +364,12 @@ enum BuiltinOperator {
|
|||||||
BuiltinOperator_REDUCE_MAX = 82,
|
BuiltinOperator_REDUCE_MAX = 82,
|
||||||
BuiltinOperator_PACK = 83,
|
BuiltinOperator_PACK = 83,
|
||||||
BuiltinOperator_LOGICAL_OR = 84,
|
BuiltinOperator_LOGICAL_OR = 84,
|
||||||
|
BuiltinOperator_ONE_HOT = 85,
|
||||||
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
BuiltinOperator_MIN = BuiltinOperator_ADD,
|
||||||
BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR
|
BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
|
||||||
};
|
};
|
||||||
|
|
||||||
inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
|
inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
|
||||||
static BuiltinOperator values[] = {
|
static BuiltinOperator values[] = {
|
||||||
BuiltinOperator_ADD,
|
BuiltinOperator_ADD,
|
||||||
BuiltinOperator_AVERAGE_POOL_2D,
|
BuiltinOperator_AVERAGE_POOL_2D,
|
||||||
@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
|
|||||||
BuiltinOperator_REDUCE_PROD,
|
BuiltinOperator_REDUCE_PROD,
|
||||||
BuiltinOperator_REDUCE_MAX,
|
BuiltinOperator_REDUCE_MAX,
|
||||||
BuiltinOperator_PACK,
|
BuiltinOperator_PACK,
|
||||||
BuiltinOperator_LOGICAL_OR
|
BuiltinOperator_LOGICAL_OR,
|
||||||
|
BuiltinOperator_ONE_HOT
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() {
|
|||||||
"REDUCE_MAX",
|
"REDUCE_MAX",
|
||||||
"PACK",
|
"PACK",
|
||||||
"LOGICAL_OR",
|
"LOGICAL_OR",
|
||||||
|
"ONE_HOT",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
@ -614,11 +620,12 @@ enum BuiltinOptions {
|
|||||||
BuiltinOptions_FakeQuantOptions = 58,
|
BuiltinOptions_FakeQuantOptions = 58,
|
||||||
BuiltinOptions_PackOptions = 59,
|
BuiltinOptions_PackOptions = 59,
|
||||||
BuiltinOptions_LogicalOrOptions = 60,
|
BuiltinOptions_LogicalOrOptions = 60,
|
||||||
|
BuiltinOptions_OneHotOptions = 61,
|
||||||
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
BuiltinOptions_MIN = BuiltinOptions_NONE,
|
||||||
BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions
|
BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
|
||||||
};
|
};
|
||||||
|
|
||||||
inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
|
inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
|
||||||
static BuiltinOptions values[] = {
|
static BuiltinOptions values[] = {
|
||||||
BuiltinOptions_NONE,
|
BuiltinOptions_NONE,
|
||||||
BuiltinOptions_Conv2DOptions,
|
BuiltinOptions_Conv2DOptions,
|
||||||
@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
|
|||||||
BuiltinOptions_ArgMinOptions,
|
BuiltinOptions_ArgMinOptions,
|
||||||
BuiltinOptions_FakeQuantOptions,
|
BuiltinOptions_FakeQuantOptions,
|
||||||
BuiltinOptions_PackOptions,
|
BuiltinOptions_PackOptions,
|
||||||
BuiltinOptions_LogicalOrOptions
|
BuiltinOptions_LogicalOrOptions,
|
||||||
|
BuiltinOptions_OneHotOptions
|
||||||
};
|
};
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() {
|
|||||||
"FakeQuantOptions",
|
"FakeQuantOptions",
|
||||||
"PackOptions",
|
"PackOptions",
|
||||||
"LogicalOrOptions",
|
"LogicalOrOptions",
|
||||||
|
"OneHotOptions",
|
||||||
nullptr
|
nullptr
|
||||||
};
|
};
|
||||||
return names;
|
return names;
|
||||||
@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> {
|
|||||||
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
|
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<> struct BuiltinOptionsTraits<OneHotOptions> {
|
||||||
|
static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
|
||||||
|
};
|
||||||
|
|
||||||
struct BuiltinOptionsUnion {
|
struct BuiltinOptionsUnion {
|
||||||
BuiltinOptions type;
|
BuiltinOptions type;
|
||||||
void *value;
|
void *value;
|
||||||
@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion {
|
|||||||
return type == BuiltinOptions_LogicalOrOptions ?
|
return type == BuiltinOptions_LogicalOrOptions ?
|
||||||
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
|
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
|
||||||
}
|
}
|
||||||
|
OneHotOptionsT *AsOneHotOptions() {
|
||||||
|
return type == BuiltinOptions_OneHotOptions ?
|
||||||
|
reinterpret_cast<OneHotOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
|
const OneHotOptionsT *AsOneHotOptions() const {
|
||||||
|
return type == BuiltinOptions_OneHotOptions ?
|
||||||
|
reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
|
||||||
@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(
|
|||||||
|
|
||||||
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
|
struct OneHotOptionsT : public flatbuffers::NativeTable {
|
||||||
|
typedef OneHotOptions TableType;
|
||||||
|
int32_t axis;
|
||||||
|
OneHotOptionsT()
|
||||||
|
: axis(0) {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
||||||
|
typedef OneHotOptionsT NativeTableType;
|
||||||
|
enum {
|
||||||
|
VT_AXIS = 4
|
||||||
|
};
|
||||||
|
int32_t axis() const {
|
||||||
|
return GetField<int32_t>(VT_AXIS, 0);
|
||||||
|
}
|
||||||
|
bool Verify(flatbuffers::Verifier &verifier) const {
|
||||||
|
return VerifyTableStart(verifier) &&
|
||||||
|
VerifyField<int32_t>(verifier, VT_AXIS) &&
|
||||||
|
verifier.EndTable();
|
||||||
|
}
|
||||||
|
OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
|
||||||
|
static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OneHotOptionsBuilder {
|
||||||
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
|
flatbuffers::uoffset_t start_;
|
||||||
|
void add_axis(int32_t axis) {
|
||||||
|
fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0);
|
||||||
|
}
|
||||||
|
explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
|
||||||
|
: fbb_(_fbb) {
|
||||||
|
start_ = fbb_.StartTable();
|
||||||
|
}
|
||||||
|
OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &);
|
||||||
|
flatbuffers::Offset<OneHotOptions> Finish() {
|
||||||
|
const auto end = fbb_.EndTable(start_);
|
||||||
|
auto o = flatbuffers::Offset<OneHotOptions>(end);
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
|
||||||
|
flatbuffers::FlatBufferBuilder &_fbb,
|
||||||
|
int32_t axis = 0) {
|
||||||
|
OneHotOptionsBuilder builder_(_fbb);
|
||||||
|
builder_.add_axis(axis);
|
||||||
|
return builder_.Finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
|
||||||
|
|
||||||
struct OperatorCodeT : public flatbuffers::NativeTable {
|
struct OperatorCodeT : public flatbuffers::NativeTable {
|
||||||
typedef OperatorCode TableType;
|
typedef OperatorCode TableType;
|
||||||
BuiltinOperator builtin_code;
|
BuiltinOperator builtin_code;
|
||||||
@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
|
|||||||
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
|
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
|
||||||
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
|
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
|
||||||
}
|
}
|
||||||
|
const OneHotOptions *builtin_options_as_OneHotOptions() const {
|
||||||
|
return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
|
||||||
|
}
|
||||||
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
const flatbuffers::Vector<uint8_t> *custom_options() const {
|
||||||
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
|
||||||
}
|
}
|
||||||
@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr
|
|||||||
return builtin_options_as_LogicalOrOptions();
|
return builtin_options_as_LogicalOrOptions();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const {
|
||||||
|
return builtin_options_as_OneHotOptions();
|
||||||
|
}
|
||||||
|
|
||||||
struct OperatorBuilder {
|
struct OperatorBuilder {
|
||||||
flatbuffers::FlatBufferBuilder &fbb_;
|
flatbuffers::FlatBufferBuilder &fbb_;
|
||||||
flatbuffers::uoffset_t start_;
|
flatbuffers::uoffset_t start_;
|
||||||
@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers:
|
|||||||
_fbb);
|
_fbb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
auto _o = new OneHotOptionsT();
|
||||||
|
UnPackTo(_o, _resolver);
|
||||||
|
return _o;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
|
(void)_o;
|
||||||
|
(void)_resolver;
|
||||||
|
{ auto _e = axis(); _o->axis = _e; };
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
return CreateOneHotOptions(_fbb, _o, _rehasher);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
|
||||||
|
(void)_rehasher;
|
||||||
|
(void)_o;
|
||||||
|
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
|
||||||
|
auto _axis = _o->axis;
|
||||||
|
return tflite::CreateOneHotOptions(
|
||||||
|
_fbb,
|
||||||
|
_axis);
|
||||||
|
}
|
||||||
|
|
||||||
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
|
||||||
auto _o = new OperatorCodeT();
|
auto _o = new OperatorCodeT();
|
||||||
UnPackTo(_o, _resolver);
|
UnPackTo(_o, _resolver);
|
||||||
@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
|
|||||||
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
||||||
return verifier.VerifyTable(ptr);
|
return verifier.VerifyTable(ptr);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_OneHotOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
|
||||||
|
return verifier.VerifyTable(ptr);
|
||||||
|
}
|
||||||
default: return false;
|
default: return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
|
|||||||
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
|
||||||
return ptr->UnPack(resolver);
|
return ptr->UnPack(resolver);
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_OneHotOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
|
||||||
|
return ptr->UnPack(resolver);
|
||||||
|
}
|
||||||
default: return nullptr;
|
default: return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
|
|||||||
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
|
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
|
||||||
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
|
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_OneHotOptions: {
|
||||||
|
auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
|
||||||
|
return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
|
||||||
|
}
|
||||||
default: return 0;
|
default: return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
|
|||||||
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
|
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_OneHotOptions: {
|
||||||
|
value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() {
|
|||||||
delete ptr;
|
delete ptr;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case BuiltinOptions_OneHotOptions: {
|
||||||
|
auto ptr = reinterpret_cast<OneHotOptionsT *>(value);
|
||||||
|
delete ptr;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
value = nullptr;
|
value = nullptr;
|
||||||
|
@ -242,7 +242,9 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
|
|||||||
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
|
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
|
||||||
elif dtype in (tf.int32, tf.uint8, tf.int64):
|
elif dtype in (tf.int32, tf.uint8, tf.int64):
|
||||||
value = np.random.randint(min_value, max_value+1, shape)
|
value = np.random.randint(min_value, max_value+1, shape)
|
||||||
return value.astype(dtype)
|
|
||||||
|
return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
|
||||||
|
dtype)
|
||||||
|
|
||||||
|
|
||||||
def create_scalar_data(dtype, min_value=-100, max_value=100):
|
def create_scalar_data(dtype, min_value=-100, max_value=100):
|
||||||
@ -1665,6 +1667,65 @@ def make_shape_tests(zip_path):
|
|||||||
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_one_hot_tests(zip_path):
|
||||||
|
"""Make a set of tests to do one_hot."""
|
||||||
|
|
||||||
|
test_parameters = [{
|
||||||
|
"indices_type": [tf.int32, tf.int64],
|
||||||
|
"indices_shape": [[3], [4, 4], [1, 5], [5, 1]],
|
||||||
|
"axis": [0, 1],
|
||||||
|
"dtype": [tf.int32, tf.int64, tf.float32],
|
||||||
|
"provide_optional_inputs": [True, False],
|
||||||
|
}]
|
||||||
|
|
||||||
|
def build_graph(parameters):
|
||||||
|
indices = tf.placeholder(
|
||||||
|
dtype=parameters["indices_type"],
|
||||||
|
name="indices",
|
||||||
|
shape=parameters["indices_shape"])
|
||||||
|
depth = tf.placeholder(dtype=tf.int32, name="depth", shape=())
|
||||||
|
|
||||||
|
if not parameters["provide_optional_inputs"]:
|
||||||
|
out = tf.one_hot(indices=indices, depth=depth)
|
||||||
|
return [indices, depth], [out]
|
||||||
|
|
||||||
|
on_value = tf.placeholder(
|
||||||
|
dtype=parameters["dtype"], name="on_value", shape=())
|
||||||
|
off_value = tf.placeholder(
|
||||||
|
dtype=parameters["dtype"], name="off_value", shape=())
|
||||||
|
out = tf.one_hot(
|
||||||
|
indices=indices,
|
||||||
|
depth=depth,
|
||||||
|
on_value=on_value,
|
||||||
|
off_value=off_value,
|
||||||
|
axis=parameters["axis"],
|
||||||
|
dtype=parameters["dtype"])
|
||||||
|
return [indices, depth, on_value, off_value], [out]
|
||||||
|
|
||||||
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
|
input_values = [
|
||||||
|
create_tensor_data(
|
||||||
|
parameters["indices_type"],
|
||||||
|
shape=parameters["indices_shape"],
|
||||||
|
min_value=-1,
|
||||||
|
max_value=10),
|
||||||
|
create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10),
|
||||||
|
]
|
||||||
|
|
||||||
|
if parameters["provide_optional_inputs"]:
|
||||||
|
input_values.append(
|
||||||
|
create_tensor_data(
|
||||||
|
parameters["dtype"], shape=None, min_value=1, max_value=10))
|
||||||
|
input_values.append(
|
||||||
|
create_tensor_data(
|
||||||
|
parameters["dtype"], shape=None, min_value=-1, max_value=0))
|
||||||
|
|
||||||
|
return input_values, sess.run(
|
||||||
|
outputs, feed_dict=dict(zip(inputs, input_values)))
|
||||||
|
|
||||||
|
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
|
||||||
|
|
||||||
|
|
||||||
def make_resize_bilinear_tests(zip_path):
|
def make_resize_bilinear_tests(zip_path):
|
||||||
"""Make a set of tests to do resize_bilinear."""
|
"""Make a set of tests to do resize_bilinear."""
|
||||||
|
|
||||||
|
@ -1316,6 +1316,20 @@ void ConvertResizeBilinearOperator(const Model& model,
|
|||||||
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
|
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
|
||||||
|
GraphDef* tensorflow_graph) {
|
||||||
|
tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
|
||||||
|
onehot_op->set_op("OneHot");
|
||||||
|
onehot_op->set_name(src_op.outputs[0]);
|
||||||
|
CHECK_EQ(src_op.inputs.size(), 4);
|
||||||
|
for (const auto& input : src_op.inputs) {
|
||||||
|
*onehot_op->add_input() = input;
|
||||||
|
}
|
||||||
|
(*onehot_op->mutable_attr())["T"].set_type(
|
||||||
|
GetTensorFlowDataType(model, src_op.outputs[0]));
|
||||||
|
(*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// TODO(aselle): Remove when available in absl
|
// TODO(aselle): Remove when available in absl
|
||||||
absl::string_view FindLongestCommonPrefix(absl::string_view a,
|
absl::string_view FindLongestCommonPrefix(absl::string_view a,
|
||||||
@ -2158,6 +2172,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
|
|||||||
ConvertLogicalNotOperator(model,
|
ConvertLogicalNotOperator(model,
|
||||||
static_cast<const LogicalNotOperator&>(src_op),
|
static_cast<const LogicalNotOperator&>(src_op),
|
||||||
tensorflow_graph);
|
tensorflow_graph);
|
||||||
|
} else if (src_op.type == OperatorType::kOneHot) {
|
||||||
|
ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
|
||||||
|
tensorflow_graph);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
|
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
|
||||||
}
|
}
|
||||||
|
@ -201,6 +201,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
|
|||||||
SetDataTypeForAllOutputs(model, op, data_type);
|
SetDataTypeForAllOutputs(model, op, data_type);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case OperatorType::kOneHot: {
|
||||||
|
CHECK_EQ(op->inputs.size(), 4);
|
||||||
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
|
const ArrayDataType on_value_type =
|
||||||
|
model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type;
|
||||||
|
const ArrayDataType off_value_type =
|
||||||
|
model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT])
|
||||||
|
.data_type;
|
||||||
|
CHECK(on_value_type == off_value_type);
|
||||||
|
model->GetArray(op->outputs[0]).data_type = on_value_type;
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// These operators produce outputs with the same type as their 1st input
|
// These operators produce outputs with the same type as their 1st input
|
||||||
CHECK_GT(op->inputs.size(), 0);
|
CHECK_GT(op->inputs.size(), 0);
|
||||||
|
@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
|
||||||
|
CHECK_EQ(op->inputs.size(), 4);
|
||||||
|
CHECK_EQ(op->outputs.size(), 1);
|
||||||
|
auto& output_array = model->GetArray(op->outputs[0]);
|
||||||
|
if (output_array.has_shape()) {
|
||||||
|
// Shape already propagated
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield until indices dims have been resolved.
|
||||||
|
const auto& indices_array =
|
||||||
|
model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
|
||||||
|
if (!indices_array.has_shape()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Yield until depth is constant and dims have been resolved.
|
||||||
|
if (!IsConstantParameterArray(*model,
|
||||||
|
op->inputs[OneHotOperator::DEPTH_INPUT])) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto& depth_array =
|
||||||
|
model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
|
||||||
|
if (!depth_array.has_shape()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK(depth_array.data_type == ArrayDataType::kInt32)
|
||||||
|
<< "Depth array must be int32.";
|
||||||
|
CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
|
||||||
|
<< "Depth array must be scalar.";
|
||||||
|
|
||||||
|
const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
|
||||||
|
CHECK_GE(depth, 0) << "Depth must be non-negative.";
|
||||||
|
|
||||||
|
const int indices_dims = indices_array.shape().dimensions_count();
|
||||||
|
const int output_dims = indices_dims + 1;
|
||||||
|
const int axis = op->axis == -1 ? indices_dims : op->axis;
|
||||||
|
CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
|
||||||
|
|
||||||
|
auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
|
||||||
|
mutable_dims->resize(output_dims);
|
||||||
|
for (int i = 0; i < output_dims; ++i) {
|
||||||
|
int dim = 0;
|
||||||
|
if (i < axis) {
|
||||||
|
dim = indices_array.shape().dims(i);
|
||||||
|
} else if (i == axis) {
|
||||||
|
dim = depth;
|
||||||
|
} else {
|
||||||
|
dim = indices_array.shape().dims(i - 1);
|
||||||
|
}
|
||||||
|
(*mutable_dims)[i] = dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
|
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
|
||||||
@ -1825,6 +1880,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
|
|||||||
case OperatorType::kAny:
|
case OperatorType::kAny:
|
||||||
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
|
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
|
||||||
break;
|
break;
|
||||||
|
case OperatorType::kOneHot:
|
||||||
|
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
// Unimplemented, another graph transformation should drop it.
|
// Unimplemented, another graph transformation should drop it.
|
||||||
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
|
||||||
|
@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
|
|||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status ConvertOneHotOperator(
|
||||||
|
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
|
||||||
|
Model* model) {
|
||||||
|
CHECK_EQ(node.op(), "OneHot");
|
||||||
|
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
|
||||||
|
|
||||||
|
const auto dtype = GetDataTypeAttr(node, "T");
|
||||||
|
// TODO(b/111744875): Support DT_UINT8 and quantization.
|
||||||
|
CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
|
||||||
|
dtype == DT_BOOL);
|
||||||
|
|
||||||
|
auto op = absl::make_unique<OneHotOperator>();
|
||||||
|
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
|
||||||
|
for (const string& input : node.input()) {
|
||||||
|
op->inputs.push_back(input);
|
||||||
|
}
|
||||||
|
op->outputs.push_back(node.name());
|
||||||
|
model->operators.emplace_back(op.release());
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -1909,6 +1930,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
|
||||||
{"NoOp", ConvertNoOpOperator},
|
{"NoOp", ConvertNoOpOperator},
|
||||||
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
|
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
|
||||||
|
{"OneHot", ConvertOneHotOperator},
|
||||||
{"Pack", ConvertPackOperator},
|
{"Pack", ConvertPackOperator},
|
||||||
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
|
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
|
||||||
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
|
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
|
||||||
|
@ -64,6 +64,7 @@ enum class OperatorType : uint8 {
|
|||||||
kMaxPool,
|
kMaxPool,
|
||||||
kFakeQuant,
|
kFakeQuant,
|
||||||
kMul,
|
kMul,
|
||||||
|
kOneHot,
|
||||||
kRandomUniform,
|
kRandomUniform,
|
||||||
kRange,
|
kRange,
|
||||||
kRank,
|
kRank,
|
||||||
@ -1768,6 +1769,27 @@ struct LogicalNotOperator : Operator {
|
|||||||
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
|
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// OneHot operator:
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// Inputs[0]: required: indices.
|
||||||
|
// Inputs[1]: required: depth.
|
||||||
|
// Inputs[2]: required: on_value.
|
||||||
|
// Inputs[3]: required: off_value.
|
||||||
|
//
|
||||||
|
// TensorFlow equivalent: OneHot.
|
||||||
|
struct OneHotOperator : Operator {
|
||||||
|
enum Inputs {
|
||||||
|
INDICES_INPUT = 0,
|
||||||
|
DEPTH_INPUT = 1,
|
||||||
|
ON_VALUE_INPUT = 2,
|
||||||
|
OFF_VALUE_INPUT = 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
OneHotOperator() : Operator(OperatorType::kOneHot) {}
|
||||||
|
int axis = -1;
|
||||||
|
};
|
||||||
|
|
||||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||||
|
@ -1053,6 +1053,23 @@ class Shape
|
|||||||
int GetVersion(const Operator& op) const override { return 1; }
|
int GetVersion(const Operator& op) const override { return 1; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
|
||||||
|
::tflite::BuiltinOptions_OneHotOptions> {
|
||||||
|
public:
|
||||||
|
using BuiltinOperator::BuiltinOperator;
|
||||||
|
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||||
|
const TocoOperator& op,
|
||||||
|
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||||
|
return ::tflite::CreateOneHotOptions(*builder, op.axis);
|
||||||
|
}
|
||||||
|
void ReadOptions(const TfLiteOptions& options,
|
||||||
|
TocoOperator* op) const override {
|
||||||
|
op->axis = options.axis();
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetVersion(const Operator& op) const override { return 1; }
|
||||||
|
};
|
||||||
|
|
||||||
class TensorFlowUnsupported : public BaseOperator {
|
class TensorFlowUnsupported : public BaseOperator {
|
||||||
public:
|
public:
|
||||||
using BaseOperator::BaseOperator;
|
using BaseOperator::BaseOperator;
|
||||||
@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
|
|||||||
OperatorType::kFakeQuant));
|
OperatorType::kFakeQuant));
|
||||||
ops.emplace_back(
|
ops.emplace_back(
|
||||||
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
|
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
|
||||||
|
ops.emplace_back(
|
||||||
|
new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
|
||||||
|
|
||||||
// Custom Operators.
|
// Custom Operators.
|
||||||
ops.emplace_back(
|
ops.emplace_back(
|
||||||
|
@ -462,6 +462,14 @@ TEST_F(OperatorTest, BuiltinPack) {
|
|||||||
EXPECT_EQ(op.axis, output_toco_op->axis);
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, BuiltinOneHot) {
|
||||||
|
OneHotOperator op;
|
||||||
|
op.axis = 2;
|
||||||
|
auto output_toco_op = SerializeAndDeserialize(
|
||||||
|
GetOperator("ONE_HOT", OperatorType::kOneHot), op);
|
||||||
|
EXPECT_EQ(op.axis, output_toco_op->axis);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, TensorFlowUnsupported) {
|
TEST_F(OperatorTest, TensorFlowUnsupported) {
|
||||||
TensorFlowUnsupportedOperator op;
|
TensorFlowUnsupportedOperator op;
|
||||||
op.tensorflow_op = "MyCustomUnsupportedOp";
|
op.tensorflow_op = "MyCustomUnsupportedOp";
|
||||||
|
@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
|
|||||||
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
|
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
|
||||||
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
|
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
|
||||||
HANDLE_OPERATORTYPENAME_CASE(Neg)
|
HANDLE_OPERATORTYPENAME_CASE(Neg)
|
||||||
|
HANDLE_OPERATORTYPENAME_CASE(OneHot)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(Pack)
|
HANDLE_OPERATORTYPENAME_CASE(Pack)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(Pad)
|
HANDLE_OPERATORTYPENAME_CASE(Pad)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(PadV2)
|
HANDLE_OPERATORTYPENAME_CASE(PadV2)
|
||||||
|
Loading…
Reference in New Issue
Block a user