From c59af2913aaec235d883f50428efef1086f4c0e6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 Sep 2020 11:37:44 -0700 Subject: [PATCH] Moving common with TFLite ragged tensor code to a separate file. PiperOrigin-RevId: 329754601 Change-Id: Id7695bad73d18a8f54fdff7a8e1898c37477f13f --- tensorflow/core/util/BUILD | 18 ++++- tensorflow/core/util/ragged_to_dense_util.cc | 50 ++----------- tensorflow/core/util/ragged_to_dense_util.h | 13 ++-- .../core/util/ragged_to_dense_util_common.cc | 70 +++++++++++++++++++ .../core/util/ragged_to_dense_util_common.h | 40 +++++++++++ 5 files changed, 138 insertions(+), 53 deletions(-) create mode 100644 tensorflow/core/util/ragged_to_dense_util_common.cc create mode 100644 tensorflow/core/util/ragged_to_dense_util_common.h diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 13e31e9be05..23564df0381 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -16,8 +16,8 @@ load( "tf_cuda_only_cc_test", "tf_kernel_library", ) -load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable") +load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") load( "//third_party/mkl:build_defs.bzl", "mkl_deps", @@ -106,6 +106,8 @@ filegroup( "ptr_util.h", "ragged_to_dense_util.cc", "ragged_to_dense_util.h", + "ragged_to_dense_util_common.cc", + "ragged_to_dense_util_common.h", "reffed_status_callback.h", "saved_tensor_slice_util.cc", "saved_tensor_slice_util.h", @@ -383,6 +385,19 @@ cc_library( ], ) +cc_library( + name = "ragged_to_dense_util_common", + srcs = [ + "ragged_to_dense_util_common.cc", + ], + hdrs = [ + "ragged_to_dense_util_common.h", + ], + visibility = [ + "//visibility:public", + ], +) + cc_library( name = "ragged_to_dense_util", srcs = [ @@ -392,6 +407,7 @@ cc_library( "ragged_to_dense_util.h", ], deps = [ + ":ragged_to_dense_util_common", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/core/util/ragged_to_dense_util.cc b/tensorflow/core/util/ragged_to_dense_util.cc index cd95b5ec75b..1d00a43a14a 100644 --- a/tensorflow/core/util/ragged_to_dense_util.cc +++ b/tensorflow/core/util/ragged_to_dense_util.cc @@ -24,43 +24,15 @@ namespace tensorflow { using errors::InvalidArgument; -string RowPartitionTypeToString(RowPartitionType row_partition_type) { - switch (row_partition_type) { - case RowPartitionType::FIRST_DIM_SIZE: - return "FIRST_DIM_SIZE"; - case RowPartitionType::VALUE_ROWIDS: - return "VALUE_ROWIDS"; - case RowPartitionType::ROW_LENGTHS: - return "ROW_LENGTHS"; - case RowPartitionType::ROW_SPLITS: - return "ROW_SPLITS"; - case RowPartitionType::ROW_LIMITS: - return "ROW_LIMITS"; - case RowPartitionType::ROW_STARTS: - return "ROW_STARTS"; - default: - return "UNKNOWN ROW PARTITION TYPE"; - } -} tensorflow::Status GetRowPartitionTypesHelper( const std::vector& row_partition_type_strings, std::vector* row_partition_types) { - static const auto kStringToType = - new std::unordered_map( - {{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE}, - {"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS}, - {"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS}, - {"ROW_SPLITS", RowPartitionType::ROW_SPLITS}, - {"ROW_LIMITS", RowPartitionType::ROW_LIMITS}, - {"ROW_STARTS", RowPartitionType::ROW_STARTS}}); - - for (const string& type_str : row_partition_type_strings) { - const auto iter = kStringToType->find(type_str); - if (iter == kStringToType->end()) { - return InvalidArgument("Unknown string for partition info type: ", - type_str); - } - row_partition_types->push_back(iter->second); + *row_partition_types = GetRowPartitionTypesHelper(row_partition_type_strings); + if (row_partition_types->size() != row_partition_type_strings.size()) { + // Something was not converted, return error status. + return InvalidArgument( + "Unknown string for partition info type: ", + row_partition_type_strings.at(row_partition_types->size())); } return tensorflow::Status::OK(); } @@ -120,16 +92,6 @@ tensorflow::Status CombineRaggedTensorToTensorShapes( return tensorflow::Status::OK(); } -int GetRaggedRank(const std::vector& row_partition_types) { - if (row_partition_types.empty()) { - return 0; - } - if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) { - return row_partition_types.size() - 1; - } - return row_partition_types.size(); -} - tensorflow::Status ValidateDefaultValueShape( const TensorShapeProto& default_value_shape, const TensorShapeProto& value_shape) { diff --git a/tensorflow/core/util/ragged_to_dense_util.h b/tensorflow/core/util/ragged_to_dense_util.h index d29d6a5b62d..28d230aa4a9 100644 --- a/tensorflow/core/util/ragged_to_dense_util.h +++ b/tensorflow/core/util/ragged_to_dense_util.h @@ -20,16 +20,9 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/util/ragged_to_dense_util_common.h" namespace tensorflow { -enum class RowPartitionType { - FIRST_DIM_SIZE, - VALUE_ROWIDS, - ROW_LENGTHS, - ROW_SPLITS, - ROW_LIMITS, - ROW_STARTS -}; string RowPartitionTypeToString(RowPartitionType row_partition_type); @@ -48,6 +41,10 @@ Status GetRowPartitionTypes( row_partition_types); } +Status GetRowPartitionTypesHelper( + const std::vector& row_partition_type_strings, + std::vector* row_partition_types); + Status CombineRaggedTensorToTensorShapes(int ragged_rank, const TensorShapeProto& shape, const TensorShapeProto& value_shape, diff --git a/tensorflow/core/util/ragged_to_dense_util_common.cc b/tensorflow/core/util/ragged_to_dense_util_common.cc new file mode 100644 index 00000000000..b2d0b2d2fd9 --- /dev/null +++ b/tensorflow/core/util/ragged_to_dense_util_common.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/ragged_to_dense_util_common.h" + +#include + +namespace tensorflow { +std::string RowPartitionTypeToString(RowPartitionType row_partition_type) { + switch (row_partition_type) { + case RowPartitionType::FIRST_DIM_SIZE: + return "FIRST_DIM_SIZE"; + case RowPartitionType::VALUE_ROWIDS: + return "VALUE_ROWIDS"; + case RowPartitionType::ROW_LENGTHS: + return "ROW_LENGTHS"; + case RowPartitionType::ROW_SPLITS: + return "ROW_SPLITS"; + case RowPartitionType::ROW_LIMITS: + return "ROW_LIMITS"; + case RowPartitionType::ROW_STARTS: + return "ROW_STARTS"; + default: + return "UNKNOWN ROW PARTITION TYPE"; + } +} + +std::vector GetRowPartitionTypesHelper( + const std::vector& row_partition_type_strings) { + static const auto kStringToType = + new std::unordered_map( + {{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE}, + {"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS}, + {"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS}, + {"ROW_SPLITS", RowPartitionType::ROW_SPLITS}, + {"ROW_LIMITS", RowPartitionType::ROW_LIMITS}, + {"ROW_STARTS", RowPartitionType::ROW_STARTS}}); + std::vector result; + for (const auto& type_str : row_partition_type_strings) { + const auto iter = kStringToType->find(type_str); + if (iter == kStringToType->end()) { + break; + } + result.push_back(iter->second); + } + return result; +} + +int GetRaggedRank(const std::vector& row_partition_types) { + if (row_partition_types.empty()) { + return 0; + } + if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) { + return row_partition_types.size() - 1; + } + return row_partition_types.size(); +} +} // namespace tensorflow diff --git a/tensorflow/core/util/ragged_to_dense_util_common.h b/tensorflow/core/util/ragged_to_dense_util_common.h new file mode 100644 index 00000000000..b43412adb59 --- /dev/null +++ b/tensorflow/core/util/ragged_to_dense_util_common.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_ +#define TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_ + +#include +#include + +namespace tensorflow { +enum class RowPartitionType { + FIRST_DIM_SIZE, + VALUE_ROWIDS, + ROW_LENGTHS, + ROW_SPLITS, + ROW_LIMITS, + ROW_STARTS +}; + +std::string RowPartitionTypeToString(RowPartitionType row_partition_type); + +std::vector GetRowPartitionTypesHelper( + const std::vector& row_partition_type_strings); + +int GetRaggedRank(const std::vector& row_partition_types); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_