Moving common with TFLite ragged tensor code to a separate file.

PiperOrigin-RevId: 329754601
Change-Id: Id7695bad73d18a8f54fdff7a8e1898c37477f13f
This commit is contained in:
A. Unique TensorFlower 2020-09-02 11:37:44 -07:00 committed by TensorFlower Gardener
parent dfc58844f9
commit c59af2913a
5 changed files with 138 additions and 53 deletions

View File

@ -16,8 +16,8 @@ load(
"tf_cuda_only_cc_test",
"tf_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
load(
"//third_party/mkl:build_defs.bzl",
"mkl_deps",
@ -106,6 +106,8 @@ filegroup(
"ptr_util.h",
"ragged_to_dense_util.cc",
"ragged_to_dense_util.h",
"ragged_to_dense_util_common.cc",
"ragged_to_dense_util_common.h",
"reffed_status_callback.h",
"saved_tensor_slice_util.cc",
"saved_tensor_slice_util.h",
@ -383,6 +385,19 @@ cc_library(
],
)
cc_library(
name = "ragged_to_dense_util_common",
srcs = [
"ragged_to_dense_util_common.cc",
],
hdrs = [
"ragged_to_dense_util_common.h",
],
visibility = [
"//visibility:public",
],
)
cc_library(
name = "ragged_to_dense_util",
srcs = [
@ -392,6 +407,7 @@ cc_library(
"ragged_to_dense_util.h",
],
deps = [
":ragged_to_dense_util_common",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
],

View File

@ -24,43 +24,15 @@ namespace tensorflow {
using errors::InvalidArgument;
string RowPartitionTypeToString(RowPartitionType row_partition_type) {
switch (row_partition_type) {
case RowPartitionType::FIRST_DIM_SIZE:
return "FIRST_DIM_SIZE";
case RowPartitionType::VALUE_ROWIDS:
return "VALUE_ROWIDS";
case RowPartitionType::ROW_LENGTHS:
return "ROW_LENGTHS";
case RowPartitionType::ROW_SPLITS:
return "ROW_SPLITS";
case RowPartitionType::ROW_LIMITS:
return "ROW_LIMITS";
case RowPartitionType::ROW_STARTS:
return "ROW_STARTS";
default:
return "UNKNOWN ROW PARTITION TYPE";
}
}
tensorflow::Status GetRowPartitionTypesHelper(
const std::vector<string>& row_partition_type_strings,
std::vector<RowPartitionType>* row_partition_types) {
static const auto kStringToType =
new std::unordered_map<string, RowPartitionType>(
{{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE},
{"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS},
{"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS},
{"ROW_SPLITS", RowPartitionType::ROW_SPLITS},
{"ROW_LIMITS", RowPartitionType::ROW_LIMITS},
{"ROW_STARTS", RowPartitionType::ROW_STARTS}});
for (const string& type_str : row_partition_type_strings) {
const auto iter = kStringToType->find(type_str);
if (iter == kStringToType->end()) {
return InvalidArgument("Unknown string for partition info type: ",
type_str);
}
row_partition_types->push_back(iter->second);
*row_partition_types = GetRowPartitionTypesHelper(row_partition_type_strings);
if (row_partition_types->size() != row_partition_type_strings.size()) {
// Something was not converted, return error status.
return InvalidArgument(
"Unknown string for partition info type: ",
row_partition_type_strings.at(row_partition_types->size()));
}
return tensorflow::Status::OK();
}
@ -120,16 +92,6 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
return tensorflow::Status::OK();
}
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types) {
if (row_partition_types.empty()) {
return 0;
}
if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) {
return row_partition_types.size() - 1;
}
return row_partition_types.size();
}
tensorflow::Status ValidateDefaultValueShape(
const TensorShapeProto& default_value_shape,
const TensorShapeProto& value_shape) {

View File

@ -20,16 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/ragged_to_dense_util_common.h"
namespace tensorflow {
enum class RowPartitionType {
FIRST_DIM_SIZE,
VALUE_ROWIDS,
ROW_LENGTHS,
ROW_SPLITS,
ROW_LIMITS,
ROW_STARTS
};
string RowPartitionTypeToString(RowPartitionType row_partition_type);
@ -48,6 +41,10 @@ Status GetRowPartitionTypes(
row_partition_types);
}
Status GetRowPartitionTypesHelper(
const std::vector<string>& row_partition_type_strings,
std::vector<RowPartitionType>* row_partition_types);
Status CombineRaggedTensorToTensorShapes(int ragged_rank,
const TensorShapeProto& shape,
const TensorShapeProto& value_shape,

View File

@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/ragged_to_dense_util_common.h"
#include <unordered_map>
namespace tensorflow {
std::string RowPartitionTypeToString(RowPartitionType row_partition_type) {
switch (row_partition_type) {
case RowPartitionType::FIRST_DIM_SIZE:
return "FIRST_DIM_SIZE";
case RowPartitionType::VALUE_ROWIDS:
return "VALUE_ROWIDS";
case RowPartitionType::ROW_LENGTHS:
return "ROW_LENGTHS";
case RowPartitionType::ROW_SPLITS:
return "ROW_SPLITS";
case RowPartitionType::ROW_LIMITS:
return "ROW_LIMITS";
case RowPartitionType::ROW_STARTS:
return "ROW_STARTS";
default:
return "UNKNOWN ROW PARTITION TYPE";
}
}
std::vector<RowPartitionType> GetRowPartitionTypesHelper(
const std::vector<std::string>& row_partition_type_strings) {
static const auto kStringToType =
new std::unordered_map<std::string, RowPartitionType>(
{{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE},
{"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS},
{"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS},
{"ROW_SPLITS", RowPartitionType::ROW_SPLITS},
{"ROW_LIMITS", RowPartitionType::ROW_LIMITS},
{"ROW_STARTS", RowPartitionType::ROW_STARTS}});
std::vector<RowPartitionType> result;
for (const auto& type_str : row_partition_type_strings) {
const auto iter = kStringToType->find(type_str);
if (iter == kStringToType->end()) {
break;
}
result.push_back(iter->second);
}
return result;
}
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types) {
if (row_partition_types.empty()) {
return 0;
}
if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) {
return row_partition_types.size() - 1;
}
return row_partition_types.size();
}
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_
#define TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_
#include <string>
#include <vector>
namespace tensorflow {
enum class RowPartitionType {
FIRST_DIM_SIZE,
VALUE_ROWIDS,
ROW_LENGTHS,
ROW_SPLITS,
ROW_LIMITS,
ROW_STARTS
};
std::string RowPartitionTypeToString(RowPartitionType row_partition_type);
std::vector<RowPartitionType> GetRowPartitionTypesHelper(
const std::vector<std::string>& row_partition_type_strings);
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_