diff --git a/tensorflow/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/tf_ops_compatibility.md
index d5b998df78a..4f5def97912 100644
--- a/tensorflow/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/lite/g3doc/tf_ops_compatibility.md
@@ -1023,6 +1023,22 @@ Outputs {
 }
 ```
 
+**WHERE**
+
+```
+Inputs {
+  0: A tensor of type bool.
+  1: A tensor which may have the same shape as condition. If condition is rank
+     1, x may have higher rank, but its first dimension must match the size of
+     condition.
+  2: A tensor with the same shape and type as x.
+}
+Outputs {
+  0: A tensor with the same type and shape as x and y if they are non-None, or
+     a tensor with shape (num_true, rank(condition)).
+}
+```
+
 **ZEROS_LIKE**
 
 ```
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index bf7dfb59f41..c24b6ede630 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -227,6 +227,7 @@ cc_library(
         "unidirectional_sequence_rnn.cc",
         "unique.cc",
         "unpack.cc",
+        "where.cc",
         "while.cc",
         "zeros_like.cc",
     ],
@@ -1187,6 +1188,19 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "where_test",
+    size = "small",
+    srcs = ["where_test.cc"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/lite:builtin_op_data",
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 tf_cc_test(
     name = "zeros_like_test",
     size = "small",
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index 515db6fd37a..34763807fa5 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -4586,6 +4586,34 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
   }
 }
 
+template <typename D, typename T>
+void SelectTrueCoords(const RuntimeShape& input_condition_shape,
+                      const D* input_condition_data, T* output_data) {
+  const size_t size = input_condition_shape.FlatSize();
+  const size_t cond_rank = input_condition_shape.DimensionsCount();
+
+  std::vector<int> dims_to_count(cond_rank, 0);
+  int cur_flat_size = size;
+  for (int i = 0; i < cond_rank; ++i) {
+    dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
+    cur_flat_size = dims_to_count[i];
+  }
+
+  int output_index = 0;
+  for (int i = 0; i < size; ++i) {
+    if (input_condition_data[i]) {
+      // Insert the coordinate of the current item (row-major) into output.
+      int flat_index = i;
+      for (int j = 0; j < cond_rank; ++j) {
+        int coord_j = flat_index / dims_to_count[j];
+        output_data[output_index * cond_rank + j] = coord_j;
+        flat_index %= dims_to_count[j];
+      }
+      output_index++;
+    }
+  }
+}
+
 // For easy implementation, the indices is always a vector of size-4 vectors.
 template <typename T, typename TI>
 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
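Not part of the patch: a minimal standalone sketch of the stride arithmetic that `SelectTrueCoords` uses. `dims_to_count[j]` holds the number of flat elements spanned by one step along dimension `j`, so each flat index of a true element decomposes into row-major coordinates by repeated division and modulo. Plain arrays stand in for `RuntimeShape`, and the 2x3 example shape is made up for illustration:

```
#include <cstdio>
#include <vector>

int main() {
  // A 2x3 condition tensor, stored row-major; true at flat indices 1, 3, 5.
  const std::vector<int> dims = {2, 3};
  const bool condition[] = {false, true, false,  // row 0
                            true, false, true};  // row 1

  // dims_to_count[j]: flat elements covered by one step along dimension j.
  // For a 2x3 shape this works out to {3, 1}.
  std::vector<int> dims_to_count(dims.size(), 0);
  int cur_flat_size = 2 * 3;
  for (size_t j = 0; j < dims.size(); ++j) {
    dims_to_count[j] = cur_flat_size / dims[j];
    cur_flat_size = dims_to_count[j];
  }

  // Decompose each true element's flat index into its coordinates.
  for (int i = 0; i < 6; ++i) {
    if (!condition[i]) continue;
    int flat_index = i;
    std::printf("flat %d ->", i);
    for (size_t j = 0; j < dims.size(); ++j) {
      std::printf(" %d", flat_index / dims_to_count[j]);
      flat_index %= dims_to_count[j];
    }
    std::printf("\n");  // "flat 1 -> 0 1", "flat 3 -> 1 0", "flat 5 -> 1 2"
  }
  return 0;
}
```

For this condition the WHERE op would therefore emit a 3x2 int32 tensor with rows (0, 1), (1, 0), and (1, 2), which is exactly what the kernel below produces.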
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index b4d1f112cd6..aa8cd2c4725 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -135,6 +135,7 @@ TfLiteRegistration* Register_UNIQUE();
 TfLiteRegistration* Register_REVERSE_V2();
 TfLiteRegistration* Register_ADD_N();
 TfLiteRegistration* Register_GATHER_ND();
+TfLiteRegistration* Register_WHERE();
 
 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
   context->ReportError(
@@ -361,6 +362,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
   AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
   AddBuiltin(BuiltinOperator_GATHER_ND, Register_GATHER_ND());
+  AddBuiltin(BuiltinOperator_WHERE, Register_WHERE());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
diff --git a/tensorflow/lite/kernels/where.cc b/tensorflow/lite/kernels/where.cc
new file mode 100644
index 00000000000..96ee36f08bc
--- /dev/null
+++ b/tensorflow/lite/kernels/where.cc
@@ -0,0 +1,105 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace where {
+
+constexpr int kInputConditionTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+                                const TfLiteTensor* cond_tensor,
+                                TfLiteTensor* output_tensor) {
+  // The output tensor should have shape:
+  // (num_true, cond_rank), where num_true denotes the number of true values
+  // in condition.
+  const RuntimeShape& cond_shape = GetTensorShape(cond_tensor);
+  const int size = cond_shape.FlatSize();
+  const int cond_rank = cond_shape.DimensionsCount();
+  const bool* cond_data = GetTensorData<bool>(cond_tensor);
+
+  int true_count = 0;
+  for (int i = 0; i < size; ++i) {
+    if (cond_data[i]) {
+      true_count++;
+    }
+  }
+  TfLiteIntArray* output_dims = TfLiteIntArrayCreate(2);
+  output_dims->data[0] = true_count;
+  output_dims->data[1] = cond_rank;
+  return context->ResizeTensor(context, output_tensor, output_dims);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  const TfLiteTensor* cond_tensor =
+      GetInput(context, node, kInputConditionTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (cond_tensor->type != kTfLiteBool) {
+    context->ReportError(context,
+                         "Condition tensor must be of type bool, but saw '%s'.",
+                         TfLiteTypeGetName(cond_tensor->type));
+    return kTfLiteError;
+  }
+
+  // As the output is a 2D tensor of indices, use int32 as the data type.
+  output->type = kTfLiteInt32;
+
+  // Exit early if cond is a non-const tensor. Set the output tensor to dynamic
+  // so its size can be determined in Eval.
+  if (!IsConstantTensor(cond_tensor)) {
+    SetTensorToDynamic(output);
+    return kTfLiteOk;
+  }
+  return ResizeOutputTensor(context, cond_tensor, output);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* cond_tensor =
+      GetInput(context, node, kInputConditionTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  if (IsDynamicTensor(output)) {
+    TF_LITE_ENSURE_OK(context,
+                      ResizeOutputTensor(context, cond_tensor, output));
+  }
+
+  reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
+                                  GetTensorData<bool>(cond_tensor),
+                                  GetTensorData<int32_t>(output));
+  return kTfLiteOk;
+}
+}  // namespace where
+
+TfLiteRegistration* Register_WHERE() {
+  static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
+                                 where::Prepare, where::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
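Not part of the patch: a rough usage sketch of the kernel registered above, driven through the C++ `Interpreter` API with a hand-built single-node graph. The tensor names, shapes, and the hand-built-graph pattern are illustrative assumptions; error checking is omitted. Because the condition tensor here is a regular read-write input rather than a constant, `Prepare` marks the output dynamic and `Invoke` resizes it to `(num_true, rank(condition))`:

```
#include <cstdio>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  // Single-node graph: tensor 0 (bool[2,2]) --WHERE--> tensor 1 (int32).
  tflite::Interpreter interpreter;
  interpreter.AddTensors(2);
  interpreter.SetInputs({0});
  interpreter.SetOutputs({1});
  TfLiteQuantizationParams quant = {};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteBool, "cond", {2, 2},
                                           quant);
  interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "coords", {},
                                           quant);

  tflite::ops::builtin::BuiltinOpResolver resolver;
  const TfLiteRegistration* where_op =
      resolver.FindOp(tflite::BuiltinOperator_WHERE, /*version=*/1);
  interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, where_op);
  interpreter.AllocateTensors();

  // cond is not a constant tensor, so Prepare left the output dynamic and
  // Invoke resizes it to {2, 2} for this input.
  bool* cond = interpreter.typed_tensor<bool>(0);
  cond[0] = true;
  cond[1] = false;
  cond[2] = false;
  cond[3] = true;
  interpreter.Invoke();

  const TfLiteTensor* coords = interpreter.tensor(1);
  for (int r = 0; r < coords->dims->data[0]; ++r) {
    // Prints "(0, 0)" and "(1, 1)": row-major coordinates of the true values.
    std::printf("(%d, %d)\n", coords->data.i32[r * 2],
                coords->data.i32[r * 2 + 1]);
  }
  return 0;
}
```

The dynamic-output path is the common case, since the condition is rarely a compile-time constant; the constant-condition branch in Prepare only lets the resize happen ahead of Invoke.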
diff --git a/tensorflow/lite/kernels/where_test.cc b/tensorflow/lite/kernels/where_test.cc
new file mode 100644
index 00000000000..89bd7c43646
--- /dev/null
+++ b/tensorflow/lite/kernels/where_test.cc
@@ -0,0 +1,161 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseWhereOpModel : public SingleOpModel {
+ public:
+  BaseWhereOpModel(const TensorData& input, const TensorData& output) {
+    input_ = AddInput(input);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
+                 CreateWhereOptions(builder_).Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+class IntegerWhereOpModel : public BaseWhereOpModel {
+ public:
+  using BaseWhereOpModel::BaseWhereOpModel;
+
+  std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
+TEST(WhereOpTest, SelectFromVectorNoResult) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {false, false, false});
+  m.Invoke();
+  EXPECT_EQ(m.GetOutput().size(), 0);
+}
+
+TEST(WhereOpTest, SelectFromVector) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, false, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2}));
+}
+
+TEST(WhereOpTest, SelectFromMatrixNoResult) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {false, false, false,  //
+                                     false, false, false,  //
+                                     false, false, false});
+  m.Invoke();
+  EXPECT_EQ(m.GetOutput().size(), 0);
+}
+
+TEST(WhereOpTest, SelectFromMatrix1) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3, 1}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, false, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0,  //
+                                               2, 0}));
+}
+
+TEST(WhereOpTest, SelectFromMatrix2) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3, 3}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, true, false,   //
+                                     true, false, false,  //
+                                     true, false, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0,  //
+                                               0, 1,  //
+                                               1, 0,  //
+                                               2, 0,  //
+                                               2, 2}));
+}
+
+TEST(WhereOpTest, SelectFromMatrix3) {
+  IntegerWhereOpModel m({TensorType_BOOL, {3, 5}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, false, false, true, true,   //
+                                     false, true, true, false, false,  //
+                                     true, false, true, false, false});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0,  //
+                                               0, 3,  //
+                                               0, 4,  //
+                                               1, 1,  //
+                                               1, 2,  //
+                                               2, 0,  //
+                                               2, 2}));
+}
+
+TEST(WhereOpTest, SelectFromRank3TensorNoResult) {
+  IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {false, false, false, false,  //
+                                     false, false, false, false});
+  m.Invoke();
+  EXPECT_EQ(m.GetOutput().size(), 0);
+}
+
+TEST(WhereOpTest, SelectFromRank3Tensor1) {
+  IntegerWhereOpModel m({TensorType_BOOL, {2, 1, 3}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, false, true,  //
+                                     false, false, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0,  //
+                                               0, 0, 2,  //
+                                               1, 0, 2}));
+}
+
+TEST(WhereOpTest, SelectFromRank3Tensor2) {
+  IntegerWhereOpModel m({TensorType_BOOL, {2, 2, 2}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, true, false, true,  //
+                                     false, false, true, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 0, 0,  //
+                                0, 0, 1,  //
+                                0, 1, 1,  //
+                                1, 1, 0,  //
+                                1, 1, 1}));
+}
+
+TEST(WhereOpTest, SelectFromRank3Tensor3) {
+  IntegerWhereOpModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_INT32, {}});
+  m.PopulateTensor<bool>(m.input(), {true, true, false, true, false, false,  //
+                                     false, false, true, false, true, true});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0,  //
+                                               0, 0, 1,  //
+                                               0, 1, 1,  //
+                                               1, 1, 0,  //
+                                               1, 2, 0,  //
+                                               1, 2, 1}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index deeda8229e5..eb2892479ff 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2514,6 +2514,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
       {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
      {"MirrorPad", ConvertMirrorPadOperator},
      {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
+      {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
  });
}
 
diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h
index 63911899aeb..e38f50e40f2 100644
--- a/tensorflow/lite/toco/model.h
+++ b/tensorflow/lite/toco/model.h
@@ -165,7 +165,8 @@ enum class OperatorType : uint8 {
   kBidirectionalSequenceLstm,
   kReverseV2,
   kBidirectionalSequenceRnn,
-  kGatherNd
+  kGatherNd,
+  kWhere
 };
 
 // Helper to deal with TensorFlow arrays using a different ordering of
@@ -2036,6 +2037,18 @@ struct UnidirectionalSequenceRnnOperator : Operator {
   FusedActivationFunctionType fused_activation_function;
 };
 
+// Where Operator:
+// Returns the coordinates of the true values in the condition tensor in
+// row-major order.
+//
+// Inputs:
+//   inputs[0]: required: boolean condition tensor
+//
+// TensorFlow equivalent: Where
+struct WhereOperator : Operator {
+  WhereOperator() : Operator(OperatorType::kWhere) {}
+};
+
 // Alloc's are used for transient arrays only. An Alloc specifies which interval
 // of the "transient_data" workspace buffer passed to inference functions, is to
 // be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index e7dbe341618..8ab5f1b37e8 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -1917,6 +1917,25 @@ class UnidirectionalSequenceRnn
   }
 };
 
+class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
+                                     ::tflite::BuiltinOptions_WhereOptions> {
+ public:
+  using BuiltinOperator::BuiltinOperator;
+
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    return ::tflite::CreateWhereOptions(*builder);
+  }
+
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {}
+
+  int GetVersion(const OperatorSignature& op_signature) const override {
+    return 1;
+  }
+};
+
 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
     const string& tensorflow_node_def) {
   auto fbb = absl::make_unique<flexbuffers::Builder>();
@@ -2398,6 +2417,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
       OperatorType::kUnidirectionalSequenceRnn));
+  ops.push_back(
+      MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
 
   // Custom Operators.
   ops.push_back(
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index f898cc5bc46..361a72ecf90 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -248,6 +248,13 @@ TEST_F(OperatorTest, BuiltinGatherNd) {
   ASSERT_NE(output_toco_op.get(), nullptr);
 }
 
+TEST_F(OperatorTest, BuiltinWhere) {
+  WhereOperator op;
+  auto output_toco_op =
+      SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
+  ASSERT_NE(output_toco_op.get(), nullptr);
+}
+
 TEST_F(OperatorTest, BuiltinL2Pool) {
   L2PoolOperator op;
   op.stride_width = 123;
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index 42d5d63d459..ccd8008eac1 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -424,6 +424,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceRnn)
     HANDLE_OPERATORTYPENAME_CASE(ReverseV2)
     HANDLE_OPERATORTYPENAME_CASE(Cos)
+    HANDLE_OPERATORTYPENAME_CASE(Where)
     default:
       LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE