Implement tf.where.
PiperOrigin-RevId: 234398807
This commit is contained in:
parent
439f3eb035
commit
d6ae732706
@ -1023,6 +1023,22 @@ Outputs {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**WHERE**
|
||||||
|
|
||||||
|
```
|
||||||
|
Inputs {
|
||||||
|
0: A tensor of type bool.
|
||||||
|
1: A tensor which may have the same shape as condition. If condition is rank
|
||||||
|
1, x may have higher rank, but its first dimension must match the size of
|
||||||
|
condition.
|
||||||
|
2: A tensor with the same shape and type as x.
|
||||||
|
}
|
||||||
|
Outputs {
|
||||||
|
0: A tensor with the same type and shape as x, y if they are non-None, or
|
||||||
|
a tensor with shape (num_true, dim_size(condition)).
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
**ZEROS_LIKE**
|
**ZEROS_LIKE**
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -227,6 +227,7 @@ cc_library(
|
|||||||
"unidirectional_sequence_rnn.cc",
|
"unidirectional_sequence_rnn.cc",
|
||||||
"unique.cc",
|
"unique.cc",
|
||||||
"unpack.cc",
|
"unpack.cc",
|
||||||
|
"where.cc",
|
||||||
"while.cc",
|
"while.cc",
|
||||||
"zeros_like.cc",
|
"zeros_like.cc",
|
||||||
],
|
],
|
||||||
@ -1187,6 +1188,19 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "where_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["where_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":builtin_ops",
|
||||||
|
"//tensorflow/lite:builtin_op_data",
|
||||||
|
"//tensorflow/lite:framework",
|
||||||
|
"//tensorflow/lite/kernels:test_util",
|
||||||
|
"@com_google_googletest//:gtest",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "zeros_like_test",
|
name = "zeros_like_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -4586,6 +4586,34 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename D, typename T>
|
||||||
|
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
|
||||||
|
const D* input_condition_data, T* output_data) {
|
||||||
|
const size_t size = input_condition_shape.FlatSize();
|
||||||
|
const size_t cond_rank = input_condition_shape.DimensionsCount();
|
||||||
|
|
||||||
|
std::vector<int> dims_to_count(cond_rank, 0);
|
||||||
|
int cur_flat_size = size;
|
||||||
|
for (int i = 0; i < cond_rank; ++i) {
|
||||||
|
dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
|
||||||
|
cur_flat_size = dims_to_count[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
int output_index = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (input_condition_data[i]) {
|
||||||
|
// Insert the coordinate of the current item (row major) into output.
|
||||||
|
int flat_index = i;
|
||||||
|
for (int j = 0; j < cond_rank; ++j) {
|
||||||
|
int coord_j = flat_index / dims_to_count[j];
|
||||||
|
output_data[output_index * cond_rank + j] = coord_j;
|
||||||
|
flat_index %= dims_to_count[j];
|
||||||
|
}
|
||||||
|
output_index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For easy implementation, the indices is always a vector of size-4 vectors.
|
// For easy implementation, the indices is always a vector of size-4 vectors.
|
||||||
template <typename T, typename TI>
|
template <typename T, typename TI>
|
||||||
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
|
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
|
||||||
|
@ -135,6 +135,7 @@ TfLiteRegistration* Register_UNIQUE();
|
|||||||
TfLiteRegistration* Register_REVERSE_V2();
|
TfLiteRegistration* Register_REVERSE_V2();
|
||||||
TfLiteRegistration* Register_ADD_N();
|
TfLiteRegistration* Register_ADD_N();
|
||||||
TfLiteRegistration* Register_GATHER_ND();
|
TfLiteRegistration* Register_GATHER_ND();
|
||||||
|
TfLiteRegistration* Register_WHERE();
|
||||||
|
|
||||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||||
context->ReportError(
|
context->ReportError(
|
||||||
@ -361,6 +362,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
|
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
|
||||||
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
|
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
|
||||||
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND());
|
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND());
|
||||||
|
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
|
||||||
|
|
||||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||||
// custom ops aren't always included by default.
|
// custom ops aren't always included by default.
|
||||||
|
105
tensorflow/lite/kernels/where.cc
Normal file
105
tensorflow/lite/kernels/where.cc
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/lite/c/c_api_internal.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||||
|
#include "tensorflow/lite/kernels/internal/types.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace ops {
|
||||||
|
namespace builtin {
|
||||||
|
namespace where {
|
||||||
|
|
||||||
|
constexpr int kInputConditionTensor = 0;
|
||||||
|
constexpr int kOutputTensor = 0;
|
||||||
|
|
||||||
|
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||||
|
const TfLiteTensor* cond_tensor,
|
||||||
|
TfLiteTensor* output_tensor) {
|
||||||
|
// Output tensor should have shape:
|
||||||
|
// (num_true, cond_rank), where num_true denotes the number of true values
|
||||||
|
// in condition.
|
||||||
|
const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
|
||||||
|
const int size = cond_shape.FlatSize();
|
||||||
|
const int cond_rank = cond_shape.DimensionsCount();
|
||||||
|
const bool* cond_data = GetTensorData<bool>(cond_tensor);
|
||||||
|
|
||||||
|
int true_count = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (cond_data[i]) {
|
||||||
|
true_count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
|
||||||
|
output_dims->data[0] = true_count;
|
||||||
|
output_dims->data[1] = cond_rank;
|
||||||
|
return context->ResizeTensor(context, output_tensor, output_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||||
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||||
|
|
||||||
|
const TfLiteTensor* cond_tensor =
|
||||||
|
GetInput(context, node, kInputConditionTensor);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
if (cond_tensor->type != kTfLiteBool) {
|
||||||
|
context->ReportError(context,
|
||||||
|
"Condition tensor must be of type bool, but saw '%s'.",
|
||||||
|
TfLiteTypeGetName(cond_tensor->type));
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// As output will be a 2D tensor of indices, we use int32 as data type.
|
||||||
|
output->type = kTfLiteInt32;
|
||||||
|
|
||||||
|
// Exit early if cond is a non-const tensor. Set output tensor to dynamic so
|
||||||
|
// output size can be determined in Eval.
|
||||||
|
if (!IsConstantTensor(cond_tensor)) {
|
||||||
|
SetTensorToDynamic(output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
return ResizeOutputTensor(context, cond_tensor, output);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
const TfLiteTensor* cond_tensor =
|
||||||
|
GetInput(context, node, kInputConditionTensor);
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||||
|
|
||||||
|
if (IsDynamicTensor(output)) {
|
||||||
|
TF_LITE_ENSURE_OK(context,
|
||||||
|
ResizeOutputTensor(context, cond_tensor, output));
|
||||||
|
}
|
||||||
|
|
||||||
|
reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
|
||||||
|
GetTensorData<bool>(cond_tensor),
|
||||||
|
GetTensorData<int32_t>(output));
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
} // namespace where
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_WHERE() {
|
||||||
|
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
|
||||||
|
where::Prepare, where::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace builtin
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace tflite
|
161
tensorflow/lite/kernels/where_test.cc
Normal file
161
tensorflow/lite/kernels/where_test.cc
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
|
namespace tflite {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
|
||||||
|
class BaseWhereOpModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
BaseWhereOpModel(const TensorData& input, const TensorData& output) {
|
||||||
|
input_ = AddInput(input);
|
||||||
|
output_ = AddOutput(output);
|
||||||
|
SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
|
||||||
|
CreateWhereOptions(builder_).Union());
|
||||||
|
BuildInterpreter({GetShape(input_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
int input() { return input_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int input_;
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class IntegerWhereOpModel : public BaseWhereOpModel {
|
||||||
|
public:
|
||||||
|
using BaseWhereOpModel::BaseWhereOpModel;
|
||||||
|
|
||||||
|
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromVectorNoResult) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {false, false, false});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput().size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromVector) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, false, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromMatrixNoResult) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {false, false, false, //
|
||||||
|
false, false, false, //
|
||||||
|
false, false, false});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_EQ(m.GetOutput().size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromMatrix1) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3, 1}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, false, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||||
|
2, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromMatrix2) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, true, false, //
|
||||||
|
true, false, false, //
|
||||||
|
true, false, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||||
|
0, 1, //
|
||||||
|
1, 0, //
|
||||||
|
2, 0, //
|
||||||
|
2, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromMatrix3) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {3, 5}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, false, false, true, true, //
|
||||||
|
false, true, true, false, false, //
|
||||||
|
true, false, true, false, false});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||||
|
0, 3, //
|
||||||
|
0, 4, //
|
||||||
|
1, 1, //
|
||||||
|
1, 2, //
|
||||||
|
2, 0, //
|
||||||
|
2, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromRank3TensorNoResult) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {false, false, false, false, //
|
||||||
|
false, false, false, false});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_EQ(m.GetOutput().size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromRank3Tensor1) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {2, 1, 3}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, false, true, //
|
||||||
|
false, false, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||||
|
0, 0, 2, //
|
||||||
|
1, 0, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromRank3Tensor2) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, true, false, true, //
|
||||||
|
false, false, true, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||||
|
0, 0, 1, //
|
||||||
|
0, 1, 1, //
|
||||||
|
1, 1, 0, //
|
||||||
|
1, 1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WhereOpTest, SelectFromRank3Tensor3) {
|
||||||
|
IntegerWhereOpModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_INT32, {}});
|
||||||
|
m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, false, //
|
||||||
|
false, false, true, false, true, true});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||||
|
0, 0, 1, //
|
||||||
|
0, 1, 1, //
|
||||||
|
1, 1, 0, //
|
||||||
|
1, 2, 0, //
|
||||||
|
1, 2, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::tflite::LogToStderr();
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
@ -2514,6 +2514,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
|||||||
{"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
|
{"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
|
||||||
{"MirrorPad", ConvertMirrorPadOperator},
|
{"MirrorPad", ConvertMirrorPadOperator},
|
||||||
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
|
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
|
||||||
|
{"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,7 +165,8 @@ enum class OperatorType : uint8 {
|
|||||||
kBidirectionalSequenceLstm,
|
kBidirectionalSequenceLstm,
|
||||||
kReverseV2,
|
kReverseV2,
|
||||||
kBidirectionalSequenceRnn,
|
kBidirectionalSequenceRnn,
|
||||||
kGatherNd
|
kGatherNd,
|
||||||
|
kWhere
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||||
@ -2036,6 +2037,18 @@ struct UnidirectionalSequenceRnnOperator : Operator {
|
|||||||
FusedActivationFunctionType fused_activation_function;
|
FusedActivationFunctionType fused_activation_function;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Where Operator:
|
||||||
|
// Return the coordinates of the true values in condition tensor in row-major
|
||||||
|
// order.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// inputs[0]: required: boolean condition tensor
|
||||||
|
//
|
||||||
|
// TensorFlow equivalent: Where
|
||||||
|
struct WhereOperator : Operator {
|
||||||
|
WhereOperator() : Operator(OperatorType::kWhere) {}
|
||||||
|
};
|
||||||
|
|
||||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||||
|
@ -1917,6 +1917,25 @@ class UnidirectionalSequenceRnn
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
|
||||||
|
::tflite::BuiltinOptions_WhereOptions> {
|
||||||
|
public:
|
||||||
|
using BuiltinOperator::BuiltinOperator;
|
||||||
|
|
||||||
|
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||||
|
const TocoOperator& op,
|
||||||
|
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||||
|
return ::tflite::CreateWhereOptions(*builder);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReadOptions(const TfLiteOptions& options,
|
||||||
|
TocoOperator* op) const override {}
|
||||||
|
|
||||||
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
|
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
|
||||||
const string& tensorflow_node_def) {
|
const string& tensorflow_node_def) {
|
||||||
auto fbb = absl::make_unique<flexbuffers::Builder>();
|
auto fbb = absl::make_unique<flexbuffers::Builder>();
|
||||||
@ -2398,6 +2417,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
|||||||
ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
|
ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
|
||||||
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
||||||
OperatorType::kUnidirectionalSequenceRnn));
|
OperatorType::kUnidirectionalSequenceRnn));
|
||||||
|
ops.push_back(
|
||||||
|
MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
|
||||||
|
|
||||||
// Custom Operators.
|
// Custom Operators.
|
||||||
ops.push_back(
|
ops.push_back(
|
||||||
|
@ -248,6 +248,13 @@ TEST_F(OperatorTest, BuiltinGatherNd) {
|
|||||||
ASSERT_NE(output_toco_op.get(), nullptr);
|
ASSERT_NE(output_toco_op.get(), nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, BuiltinWhere) {
|
||||||
|
WhereOperator op;
|
||||||
|
auto output_toco_op =
|
||||||
|
SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
|
||||||
|
ASSERT_NE(output_toco_op.get(), nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, BuiltinL2Pool) {
|
TEST_F(OperatorTest, BuiltinL2Pool) {
|
||||||
L2PoolOperator op;
|
L2PoolOperator op;
|
||||||
op.stride_width = 123;
|
op.stride_width = 123;
|
||||||
|
@ -424,6 +424,7 @@ const char* OperatorTypeName(OperatorType type) {
|
|||||||
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
|
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
|
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
|
||||||
HANDLE_OPERATORTYPENAME_CASE(Cos)
|
HANDLE_OPERATORTYPENAME_CASE(Cos)
|
||||||
|
HANDLE_OPERATORTYPENAME_CASE(Where)
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unhandled op type";
|
LOG(FATAL) << "Unhandled op type";
|
||||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||||
|
Loading…
Reference in New Issue
Block a user