Implement tf.where.
PiperOrigin-RevId: 234398807
This commit is contained in:
parent
439f3eb035
commit
d6ae732706
@ -1023,6 +1023,22 @@ Outputs {
|
||||
}
|
||||
```
|
||||
|
||||
**WHERE**
|
||||
|
||||
```
|
||||
Inputs {
|
||||
0: A tensor of type bool.
|
||||
1: A tensor which may have the same shape as condition. If condition is rank
|
||||
1, x may have higher rank, but its first dimension must match the size of
|
||||
condition.
|
||||
2: A tensor with the same shape and type as x.
|
||||
}
|
||||
Outputs {
|
||||
0: A tensor with the same type and shape as x, y if they are non-None, or
|
||||
a tensor with shape (num_true, dim_size(condition)).
|
||||
}
|
||||
```
|
||||
|
||||
**ZEROS_LIKE**
|
||||
|
||||
```
|
||||
|
@ -227,6 +227,7 @@ cc_library(
|
||||
"unidirectional_sequence_rnn.cc",
|
||||
"unique.cc",
|
||||
"unpack.cc",
|
||||
"where.cc",
|
||||
"while.cc",
|
||||
"zeros_like.cc",
|
||||
],
|
||||
@ -1187,6 +1188,19 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "where_test",
|
||||
size = "small",
|
||||
srcs = ["where_test.cc"],
|
||||
deps = [
|
||||
":builtin_ops",
|
||||
"//tensorflow/lite:builtin_op_data",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/kernels:test_util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "zeros_like_test",
|
||||
size = "small",
|
||||
|
@ -4586,6 +4586,34 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename D, typename T>
|
||||
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
|
||||
const D* input_condition_data, T* output_data) {
|
||||
const size_t size = input_condition_shape.FlatSize();
|
||||
const size_t cond_rank = input_condition_shape.DimensionsCount();
|
||||
|
||||
std::vector<int> dims_to_count(cond_rank, 0);
|
||||
int cur_flat_size = size;
|
||||
for (int i = 0; i < cond_rank; ++i) {
|
||||
dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
|
||||
cur_flat_size = dims_to_count[i];
|
||||
}
|
||||
|
||||
int output_index = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (input_condition_data[i]) {
|
||||
// Insert the coordinate of the current item (row major) into output.
|
||||
int flat_index = i;
|
||||
for (int j = 0; j < cond_rank; ++j) {
|
||||
int coord_j = flat_index / dims_to_count[j];
|
||||
output_data[output_index * cond_rank + j] = coord_j;
|
||||
flat_index %= dims_to_count[j];
|
||||
}
|
||||
output_index++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For easy implementation, the indices is always a vector of size-4 vectors.
|
||||
template <typename T, typename TI>
|
||||
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
|
||||
|
@ -135,6 +135,7 @@ TfLiteRegistration* Register_UNIQUE();
|
||||
TfLiteRegistration* Register_REVERSE_V2();
|
||||
TfLiteRegistration* Register_ADD_N();
|
||||
TfLiteRegistration* Register_GATHER_ND();
|
||||
TfLiteRegistration* Register_WHERE();
|
||||
|
||||
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
|
||||
context->ReportError(
|
||||
@ -361,6 +362,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
|
||||
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
|
||||
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND());
|
||||
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
|
||||
|
||||
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
|
||||
// custom ops aren't always included by default.
|
||||
|
105
tensorflow/lite/kernels/where.cc
Normal file
105
tensorflow/lite/kernels/where.cc
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/c/c_api_internal.h"
|
||||
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/internal/types.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace builtin {
|
||||
namespace where {
|
||||
|
||||
constexpr int kInputConditionTensor = 0;
|
||||
constexpr int kOutputTensor = 0;
|
||||
|
||||
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
|
||||
const TfLiteTensor* cond_tensor,
|
||||
TfLiteTensor* output_tensor) {
|
||||
// Output tensor should have shape:
|
||||
// (num_true, cond_rank), where num_true denotes the number of true values
|
||||
// in condition.
|
||||
const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
|
||||
const int size = cond_shape.FlatSize();
|
||||
const int cond_rank = cond_shape.DimensionsCount();
|
||||
const bool* cond_data = GetTensorData<bool>(cond_tensor);
|
||||
|
||||
int true_count = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if (cond_data[i]) {
|
||||
true_count++;
|
||||
}
|
||||
}
|
||||
TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
|
||||
output_dims->data[0] = true_count;
|
||||
output_dims->data[1] = cond_rank;
|
||||
return context->ResizeTensor(context, output_tensor, output_dims);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
|
||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
|
||||
|
||||
const TfLiteTensor* cond_tensor =
|
||||
GetInput(context, node, kInputConditionTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (cond_tensor->type != kTfLiteBool) {
|
||||
context->ReportError(context,
|
||||
"Condition tensor must be of type bool, but saw '%s'.",
|
||||
TfLiteTypeGetName(cond_tensor->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// As output will be a 2D tensor of indices, we use int32 as data type.
|
||||
output->type = kTfLiteInt32;
|
||||
|
||||
// Exit early if cond is a non-const tensor. Set output tensor to dynamic so
|
||||
// output size can be determined in Eval.
|
||||
if (!IsConstantTensor(cond_tensor)) {
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
return ResizeOutputTensor(context, cond_tensor, output);
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* cond_tensor =
|
||||
GetInput(context, node, kInputConditionTensor);
|
||||
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
|
||||
|
||||
if (IsDynamicTensor(output)) {
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
ResizeOutputTensor(context, cond_tensor, output));
|
||||
}
|
||||
|
||||
reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
|
||||
GetTensorData<bool>(cond_tensor),
|
||||
GetTensorData<int32_t>(output));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace where
|
||||
|
||||
TfLiteRegistration* Register_WHERE() {
|
||||
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
|
||||
where::Prepare, where::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
161
tensorflow/lite/kernels/where_test.cc
Normal file
161
tensorflow/lite/kernels/where_test.cc
Normal file
@ -0,0 +1,161 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
class BaseWhereOpModel : public SingleOpModel {
|
||||
public:
|
||||
BaseWhereOpModel(const TensorData& input, const TensorData& output) {
|
||||
input_ = AddInput(input);
|
||||
output_ = AddOutput(output);
|
||||
SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
|
||||
CreateWhereOptions(builder_).Union());
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
class IntegerWhereOpModel : public BaseWhereOpModel {
|
||||
public:
|
||||
using BaseWhereOpModel::BaseWhereOpModel;
|
||||
|
||||
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
|
||||
};
|
||||
|
||||
TEST(WhereOpTest, SelectFromVectorNoResult) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {false, false, false});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput().size(), 0);
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromVector) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, false, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromMatrixNoResult) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {false, false, false, //
|
||||
false, false, false, //
|
||||
false, false, false});
|
||||
m.Invoke();
|
||||
EXPECT_EQ(m.GetOutput().size(), 0);
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromMatrix1) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3, 1}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, false, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||
2, 0}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromMatrix2) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, true, false, //
|
||||
true, false, false, //
|
||||
true, false, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||
0, 1, //
|
||||
1, 0, //
|
||||
2, 0, //
|
||||
2, 2}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromMatrix3) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3, 5}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, false, false, true, true, //
|
||||
false, true, true, false, false, //
|
||||
true, false, true, false, false});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
|
||||
0, 3, //
|
||||
0, 4, //
|
||||
1, 1, //
|
||||
1, 2, //
|
||||
2, 0, //
|
||||
2, 2}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromRank3TensorNoResult) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {false, false, false, false, //
|
||||
false, false, false, false});
|
||||
m.Invoke();
|
||||
EXPECT_EQ(m.GetOutput().size(), 0);
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromRank3Tensor1) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {2, 1, 3}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, false, true, //
|
||||
false, false, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||
0, 0, 2, //
|
||||
1, 0, 2}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromRank3Tensor2) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, true, false, true, //
|
||||
false, false, true, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||
0, 0, 1, //
|
||||
0, 1, 1, //
|
||||
1, 1, 0, //
|
||||
1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromRank3Tensor3) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_INT32, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, false, //
|
||||
false, false, true, false, true, true});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
|
||||
0, 0, 1, //
|
||||
0, 1, 1, //
|
||||
1, 1, 0, //
|
||||
1, 2, 0, //
|
||||
1, 2, 1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
::tflite::LogToStderr();
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
@ -2514,6 +2514,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
|
||||
{"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
|
||||
{"MirrorPad", ConvertMirrorPadOperator},
|
||||
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
|
||||
{"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -165,7 +165,8 @@ enum class OperatorType : uint8 {
|
||||
kBidirectionalSequenceLstm,
|
||||
kReverseV2,
|
||||
kBidirectionalSequenceRnn,
|
||||
kGatherNd
|
||||
kGatherNd,
|
||||
kWhere
|
||||
};
|
||||
|
||||
// Helper to deal with TensorFlow arrays using a different ordering of
|
||||
@ -2036,6 +2037,18 @@ struct UnidirectionalSequenceRnnOperator : Operator {
|
||||
FusedActivationFunctionType fused_activation_function;
|
||||
};
|
||||
|
||||
// Where Operator:
|
||||
// Return the coordinates of the true values in condition tensor in row-major
|
||||
// order.
|
||||
//
|
||||
// Inputs:
|
||||
// inputs[0]: required: boolean condition tensor
|
||||
//
|
||||
// TensorFlow equivalent: Where
|
||||
struct WhereOperator : Operator {
|
||||
WhereOperator() : Operator(OperatorType::kWhere) {}
|
||||
};
|
||||
|
||||
// Alloc's are used for transient arrays only. An Alloc specifies which interval
|
||||
// of the "transient_data" workspace buffer passed to inference functions, is to
|
||||
// be used for the transient array at hand. The 'start' and 'end' values are
|
||||
|
@ -1917,6 +1917,25 @@ class UnidirectionalSequenceRnn
|
||||
}
|
||||
};
|
||||
|
||||
class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
|
||||
::tflite::BuiltinOptions_WhereOptions> {
|
||||
public:
|
||||
using BuiltinOperator::BuiltinOperator;
|
||||
|
||||
flatbuffers::Offset<TfLiteOptions> WriteOptions(
|
||||
const TocoOperator& op,
|
||||
flatbuffers::FlatBufferBuilder* builder) const override {
|
||||
return ::tflite::CreateWhereOptions(*builder);
|
||||
}
|
||||
|
||||
void ReadOptions(const TfLiteOptions& options,
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
|
||||
const string& tensorflow_node_def) {
|
||||
auto fbb = absl::make_unique<flexbuffers::Builder>();
|
||||
@ -2398,6 +2417,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
|
||||
ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
|
||||
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
|
||||
OperatorType::kUnidirectionalSequenceRnn));
|
||||
ops.push_back(
|
||||
MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
|
||||
|
||||
// Custom Operators.
|
||||
ops.push_back(
|
||||
|
@ -248,6 +248,13 @@ TEST_F(OperatorTest, BuiltinGatherNd) {
|
||||
ASSERT_NE(output_toco_op.get(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinWhere) {
|
||||
WhereOperator op;
|
||||
auto output_toco_op =
|
||||
SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
|
||||
ASSERT_NE(output_toco_op.get(), nullptr);
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, BuiltinL2Pool) {
|
||||
L2PoolOperator op;
|
||||
op.stride_width = 123;
|
||||
|
@ -424,6 +424,7 @@ const char* OperatorTypeName(OperatorType type) {
|
||||
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
|
||||
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Cos)
|
||||
HANDLE_OPERATORTYPENAME_CASE(Where)
|
||||
default:
|
||||
LOG(FATAL) << "Unhandled op type";
|
||||
#undef HANDLE_OPERATORTYPENAME_CASE
|
||||
|
Loading…
Reference in New Issue
Block a user