Implement tf.where.

PiperOrigin-RevId: 234398807
Haoliang Zhang 2019-02-17 16:00:41 -08:00 committed by TensorFlower Gardener
parent 439f3eb035
commit d6ae732706
11 changed files with 370 additions and 1 deletion

View File

@@ -1023,6 +1023,22 @@ Outputs {
}
```
**WHERE**
```
Inputs {
0: A tensor of type bool.
1: A tensor which may have the same shape as condition. If condition is rank
1, x may have higher rank, but its first dimension must match the size of
condition.
2: A tensor with the same shape and type as x.
}
Outputs {
0: A tensor with the same type and shape as x and y if they are non-None, or
a 2-D tensor of shape (num_true, rank(condition)) holding the row-major
coordinates of the true elements.
}
```
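The kernel added by this commit accepts only the single-input form (just the condition tensor; Prepare enforces NumInputs == 1) and returns the row-major coordinates of the true elements as an int32 tensor of shape (num_true, rank(condition)). As a minimal standalone sketch of that semantics, independent of the TFLite headers (the function WhereCoords and its layout are illustrative, not part of the commit):
```
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: returns the row-major coordinates of the true elements of
// `condition` (whose shape is `dims`), flattened as num_true * rank int32
// values, mirroring the WHERE output shape (num_true, rank(condition)).
std::vector<int32_t> WhereCoords(const std::vector<bool>& condition,
                                 const std::vector<int>& dims) {
  std::vector<int32_t> coords;
  for (size_t i = 0; i < condition.size(); ++i) {
    if (!condition[i]) continue;
    // Decompose the flat index into one coordinate per dimension.
    size_t rem = i;
    std::vector<int32_t> c(dims.size());
    for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
      c[d] = static_cast<int32_t>(rem % dims[d]);
      rem /= dims[d];
    }
    coords.insert(coords.end(), c.begin(), c.end());
  }
  return coords;
}

int main() {
  // condition = [[true, false, false], [false, false, true]], shape (2, 3).
  const std::vector<bool> condition = {true, false, false, false, false, true};
  const std::vector<int32_t> coords = WhereCoords(condition, {2, 3});
  // Prints "0 0 1 2 ", i.e. rows (0, 0) and (1, 2) of a (2, 2) output.
  for (int32_t v : coords) std::cout << v << ' ';
  std::cout << '\n';
}
```
The reference kernel added below (SelectTrueCoords) computes the same mapping using per-dimension strides instead of modulo arithmetic.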
**ZEROS_LIKE**
```

View File

@@ -227,6 +227,7 @@ cc_library(
"unidirectional_sequence_rnn.cc",
"unique.cc",
"unpack.cc",
"where.cc",
"while.cc",
"zeros_like.cc",
],
@@ -1187,6 +1188,19 @@ tf_cc_test(
],
)
tf_cc_test(
name = "where_test",
size = "small",
srcs = ["where_test.cc"],
deps = [
":builtin_ops",
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
tf_cc_test(
name = "zeros_like_test",
size = "small",

View File

@@ -4586,6 +4586,34 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
}
}
template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
const D* input_condition_data, T* output_data) {
const size_t size = input_condition_shape.FlatSize();
const size_t cond_rank = input_condition_shape.DimensionsCount();
std::vector<int> dims_to_count(cond_rank, 0);
int cur_flat_size = size;
// dims_to_count[i] holds the row-major stride of dimension i, i.e. how many
// elements one step along that dimension spans in the flat buffer.
for (int i = 0; i < cond_rank; ++i) {
dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
cur_flat_size = dims_to_count[i];
}
int output_index = 0;
for (int i = 0; i < size; ++i) {
if (input_condition_data[i]) {
// Insert the coordinate of the current item (row major) into output.
int flat_index = i;
for (int j = 0; j < cond_rank; ++j) {
int coord_j = flat_index / dims_to_count[j];
output_data[output_index * cond_rank + j] = coord_j;
flat_index %= dims_to_count[j];
}
output_index++;
}
}
}
// For easy implementation, the indices is always a vector of size-4 vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
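A hedged usage sketch of the new SelectTrueCoords helper, assuming the RuntimeShape initializer-list constructor and the header paths as of this commit; the buffers and main() are illustrative only:
```
#include <cstdint>
#include <iostream>
#include <vector>

#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"

int main() {
  // condition = [[true, false], [false, true]]: true at (0, 0) and (1, 1).
  const bool condition[] = {true, false, false, true};
  const tflite::RuntimeShape cond_shape({2, 2});

  // Two true elements * rank 2 = 4 coordinate values.
  std::vector<int32_t> coords(4);
  tflite::reference_ops::SelectTrueCoords(cond_shape, condition, coords.data());

  for (int32_t c : coords) std::cout << c << ' ';  // Prints "0 0 1 1 ".
  std::cout << '\n';
}
```
The output buffer must be sized to num_true * rank in advance, which is what the kernel's ResizeOutputTensor (in where.cc below) arranges.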

View File

@@ -135,6 +135,7 @@ TfLiteRegistration* Register_UNIQUE();
TfLiteRegistration* Register_REVERSE_V2();
TfLiteRegistration* Register_ADD_N();
TfLiteRegistration* Register_GATHER_ND();
TfLiteRegistration* Register_WHERE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -361,6 +362,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND());
AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
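Builds that assemble their own resolver instead of the full BuiltinOpResolver above can register the kernel the same way; a minimal sketch, assuming the standard MutableOpResolver API and header path (the helper name AddWhereOp is hypothetical):
```
#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_WHERE();  // Implemented in where.cc below.
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

// Hypothetical helper: register just the WHERE kernel on a custom resolver.
inline void AddWhereOp(tflite::MutableOpResolver* resolver) {
  resolver->AddBuiltin(tflite::BuiltinOperator_WHERE,
                       tflite::ops::builtin::Register_WHERE());
}
```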

View File

@@ -0,0 +1,105 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace where {
constexpr int kInputConditionTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const TfLiteTensor* cond_tensor,
TfLiteTensor* output_tensor) {
// Output tensor should have shape:
// (num_true, cond_rank), where num_true denotes the number of true values
// in condition.
const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
const int size = cond_shape.FlatSize();
const int cond_rank = cond_shape.DimensionsCount();
const bool* cond_data = GetTensorData<bool>(cond_tensor);
int true_count = 0;
for (int i = 0; i < size; ++i) {
if (cond_data[i]) {
true_count++;
}
}
TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
output_dims->data[0] = true_count;
output_dims->data[1] = cond_rank;
return context->ResizeTensor(context, output_tensor, output_dims);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (cond_tensor->type != kTfLiteBool) {
context->ReportError(context,
"Condition tensor must be of type bool, but saw '%s'.",
TfLiteTypeGetName(cond_tensor->type));
return kTfLiteError;
}
// As output will be a 2D tensor of indices, we use int32 as data type.
output->type = kTfLiteInt32;
// Exit early if cond is a non-const tensor. Set output tensor to dynamic so
// output size can be determined in Eval.
if (!IsConstantTensor(cond_tensor)) {
SetTensorToDynamic(output);
return kTfLiteOk;
}
return ResizeOutputTensor(context, cond_tensor, output);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* cond_tensor =
GetInput(context, node, kInputConditionTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
ResizeOutputTensor(context, cond_tensor, output));
}
reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
GetTensorData<bool>(cond_tensor),
GetTensorData<int32_t>(output));
return kTfLiteOk;
}
} // namespace where
TfLiteRegistration* Register_WHERE() {
static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
where::Prepare, where::Eval};
return &r;
}
} // namespace builtin
} // namespace ops
} // namespace tflite
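Since the output is marked dynamic whenever the condition is not a constant tensor (see the comment in Prepare above), its final shape is only known after Invoke(); a hedged sketch of reading it back through the Interpreter API (the helper name and parameters are assumptions):
```
#include <utility>

#include "tensorflow/lite/interpreter.h"

// Hypothetical helper: after interpreter.Invoke() has run, report the WHERE
// output's shape, i.e. (number of true elements, rank of the condition).
inline std::pair<int, int> WhereOutputShape(const tflite::Interpreter& interpreter,
                                            int output_tensor_index) {
  const TfLiteTensor* out = interpreter.tensor(output_tensor_index);
  return {out->dims->data[0], out->dims->data[1]};
}
```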

View File

@@ -0,0 +1,161 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
namespace tflite {
namespace {
using ::testing::ElementsAreArray;
class BaseWhereOpModel : public SingleOpModel {
public:
BaseWhereOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
CreateWhereOptions(builder_).Union());
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
protected:
int input_;
int output_;
};
class IntegerWhereOpModel : public BaseWhereOpModel {
public:
using BaseWhereOpModel::BaseWhereOpModel;
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
};
TEST(WhereOpTest, SelectFromVectorNoResult) {
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {false, false, false});
m.Invoke();
EXPECT_THAT(m.GetOutput().size(), 0);
}
TEST(WhereOpTest, SelectFromVector) {
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, false, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2}));
}
TEST(WhereOpTest, SelectFromMatrixNoResult) {
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {false, false, false, //
false, false, false, //
false, false, false});
m.Invoke();
EXPECT_EQ(m.GetOutput().size(), 0);
}
TEST(WhereOpTest, SelectFromMatrix1) {
IntegerWhereOpModel m({TensorType_BOOL, {3, 1}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, false, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
2, 0}));
}
TEST(WhereOpTest, SelectFromMatrix2) {
IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, true, false, //
true, false, false, //
true, false, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
0, 1, //
1, 0, //
2, 0, //
2, 2}));
}
TEST(WhereOpTest, SelectFromMatrix3) {
IntegerWhereOpModel m({TensorType_BOOL, {3, 5}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, false, false, true, true, //
false, true, true, false, false, //
true, false, true, false, false});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, //
0, 3, //
0, 4, //
1, 1, //
1, 2, //
2, 0, //
2, 2}));
}
TEST(WhereOpTest, SelectFromRank3TensorNoResult) {
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {false, false, false, false, //
false, false, false, false});
m.Invoke();
EXPECT_EQ(m.GetOutput().size(), 0);
}
TEST(WhereOpTest, SelectFromRank3Tensor1) {
IntegerWhereOpModel m({TensorType_BOOL, {2, 1, 3}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, false, true, //
false, false, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
0, 0, 2, //
1, 0, 2}));
}
TEST(WhereOpTest, SelectFromRank3Tensor2) {
IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, true, false, true, //
false, false, true, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
0, 0, 1, //
0, 1, 1, //
1, 1, 0, //
1, 1, 1}));
}
TEST(WhereOpTest, SelectFromRank3Tensor3) {
IntegerWhereOpModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_INT32, {}});
m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, false, //
false, false, true, false, true, true});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, //
0, 0, 1, //
0, 1, 1, //
1, 1, 0, //
1, 2, 0, //
1, 2, 1}));
}
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@@ -2514,6 +2514,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
{"MirrorPad", ConvertMirrorPadOperator},
{"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
{"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
});
}

View File

@@ -165,7 +165,8 @@ enum class OperatorType : uint8 {
kBidirectionalSequenceLstm,
kReverseV2,
kBidirectionalSequenceRnn,
kGatherNd,
kWhere
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -2036,6 +2037,18 @@ struct UnidirectionalSequenceRnnOperator : Operator {
FusedActivationFunctionType fused_activation_function;
};
// Where Operator:
// Returns the coordinates of the true values in the condition tensor, in
// row-major order.
//
// Inputs:
// inputs[0]: required: boolean condition tensor
//
// TensorFlow equivalent: Where
struct WhereOperator : Operator {
WhereOperator() : Operator(OperatorType::kWhere) {}
};
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are

View File

@@ -1917,6 +1917,25 @@ class UnidirectionalSequenceRnn
}
};
class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
::tflite::BuiltinOptions_WhereOptions> {
public:
using BuiltinOperator::BuiltinOperator;
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
return ::tflite::CreateWhereOptions(*builder);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override {
return 1;
}
};
std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
const string& tensorflow_node_def) {
auto fbb = absl::make_unique<flexbuffers::Builder>();
@@ -2398,6 +2417,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
OperatorType::kUnidirectionalSequenceRnn));
ops.push_back(
MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
// Custom Operators.
ops.push_back(

View File

@@ -248,6 +248,13 @@ TEST_F(OperatorTest, BuiltinGatherNd) {
ASSERT_NE(output_toco_op.get(), nullptr);
}
TEST_F(OperatorTest, BuiltinWhere) {
WhereOperator op;
auto output_toco_op =
SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
ASSERT_NE(output_toco_op.get(), nullptr);
}
TEST_F(OperatorTest, BuiltinL2Pool) {
L2PoolOperator op;
op.stride_width = 123;

View File

@@ -424,6 +424,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
HANDLE_OPERATORTYPENAME_CASE(Cos)
HANDLE_OPERATORTYPENAME_CASE(Where)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE