Moving common with TFLite ragged tensor code to a separate file.
PiperOrigin-RevId: 329754601 Change-Id: Id7695bad73d18a8f54fdff7a8e1898c37477f13f
This commit is contained in:
parent
dfc58844f9
commit
c59af2913a
@ -16,8 +16,8 @@ load(
|
||||
"tf_cuda_only_cc_test",
|
||||
"tf_kernel_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
|
||||
load("//tensorflow:tensorflow.bzl", "get_compatible_with_portable")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"mkl_deps",
|
||||
@ -106,6 +106,8 @@ filegroup(
|
||||
"ptr_util.h",
|
||||
"ragged_to_dense_util.cc",
|
||||
"ragged_to_dense_util.h",
|
||||
"ragged_to_dense_util_common.cc",
|
||||
"ragged_to_dense_util_common.h",
|
||||
"reffed_status_callback.h",
|
||||
"saved_tensor_slice_util.cc",
|
||||
"saved_tensor_slice_util.h",
|
||||
@ -383,6 +385,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ragged_to_dense_util_common",
|
||||
srcs = [
|
||||
"ragged_to_dense_util_common.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"ragged_to_dense_util_common.h",
|
||||
],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ragged_to_dense_util",
|
||||
srcs = [
|
||||
@ -392,6 +407,7 @@ cc_library(
|
||||
"ragged_to_dense_util.h",
|
||||
],
|
||||
deps = [
|
||||
":ragged_to_dense_util_common",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
|
@ -24,43 +24,15 @@ namespace tensorflow {
|
||||
|
||||
using errors::InvalidArgument;
|
||||
|
||||
string RowPartitionTypeToString(RowPartitionType row_partition_type) {
|
||||
switch (row_partition_type) {
|
||||
case RowPartitionType::FIRST_DIM_SIZE:
|
||||
return "FIRST_DIM_SIZE";
|
||||
case RowPartitionType::VALUE_ROWIDS:
|
||||
return "VALUE_ROWIDS";
|
||||
case RowPartitionType::ROW_LENGTHS:
|
||||
return "ROW_LENGTHS";
|
||||
case RowPartitionType::ROW_SPLITS:
|
||||
return "ROW_SPLITS";
|
||||
case RowPartitionType::ROW_LIMITS:
|
||||
return "ROW_LIMITS";
|
||||
case RowPartitionType::ROW_STARTS:
|
||||
return "ROW_STARTS";
|
||||
default:
|
||||
return "UNKNOWN ROW PARTITION TYPE";
|
||||
}
|
||||
}
|
||||
tensorflow::Status GetRowPartitionTypesHelper(
|
||||
const std::vector<string>& row_partition_type_strings,
|
||||
std::vector<RowPartitionType>* row_partition_types) {
|
||||
static const auto kStringToType =
|
||||
new std::unordered_map<string, RowPartitionType>(
|
||||
{{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE},
|
||||
{"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS},
|
||||
{"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS},
|
||||
{"ROW_SPLITS", RowPartitionType::ROW_SPLITS},
|
||||
{"ROW_LIMITS", RowPartitionType::ROW_LIMITS},
|
||||
{"ROW_STARTS", RowPartitionType::ROW_STARTS}});
|
||||
|
||||
for (const string& type_str : row_partition_type_strings) {
|
||||
const auto iter = kStringToType->find(type_str);
|
||||
if (iter == kStringToType->end()) {
|
||||
return InvalidArgument("Unknown string for partition info type: ",
|
||||
type_str);
|
||||
}
|
||||
row_partition_types->push_back(iter->second);
|
||||
*row_partition_types = GetRowPartitionTypesHelper(row_partition_type_strings);
|
||||
if (row_partition_types->size() != row_partition_type_strings.size()) {
|
||||
// Something was not converted, return error status.
|
||||
return InvalidArgument(
|
||||
"Unknown string for partition info type: ",
|
||||
row_partition_type_strings.at(row_partition_types->size()));
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
@ -120,16 +92,6 @@ tensorflow::Status CombineRaggedTensorToTensorShapes(
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types) {
|
||||
if (row_partition_types.empty()) {
|
||||
return 0;
|
||||
}
|
||||
if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) {
|
||||
return row_partition_types.size() - 1;
|
||||
}
|
||||
return row_partition_types.size();
|
||||
}
|
||||
|
||||
tensorflow::Status ValidateDefaultValueShape(
|
||||
const TensorShapeProto& default_value_shape,
|
||||
const TensorShapeProto& value_shape) {
|
||||
|
@ -20,16 +20,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/util/ragged_to_dense_util_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
enum class RowPartitionType {
|
||||
FIRST_DIM_SIZE,
|
||||
VALUE_ROWIDS,
|
||||
ROW_LENGTHS,
|
||||
ROW_SPLITS,
|
||||
ROW_LIMITS,
|
||||
ROW_STARTS
|
||||
};
|
||||
|
||||
string RowPartitionTypeToString(RowPartitionType row_partition_type);
|
||||
|
||||
@ -48,6 +41,10 @@ Status GetRowPartitionTypes(
|
||||
row_partition_types);
|
||||
}
|
||||
|
||||
Status GetRowPartitionTypesHelper(
|
||||
const std::vector<string>& row_partition_type_strings,
|
||||
std::vector<RowPartitionType>* row_partition_types);
|
||||
|
||||
Status CombineRaggedTensorToTensorShapes(int ragged_rank,
|
||||
const TensorShapeProto& shape,
|
||||
const TensorShapeProto& value_shape,
|
||||
|
70
tensorflow/core/util/ragged_to_dense_util_common.cc
Normal file
70
tensorflow/core/util/ragged_to_dense_util_common.cc
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/ragged_to_dense_util_common.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace tensorflow {
|
||||
std::string RowPartitionTypeToString(RowPartitionType row_partition_type) {
|
||||
switch (row_partition_type) {
|
||||
case RowPartitionType::FIRST_DIM_SIZE:
|
||||
return "FIRST_DIM_SIZE";
|
||||
case RowPartitionType::VALUE_ROWIDS:
|
||||
return "VALUE_ROWIDS";
|
||||
case RowPartitionType::ROW_LENGTHS:
|
||||
return "ROW_LENGTHS";
|
||||
case RowPartitionType::ROW_SPLITS:
|
||||
return "ROW_SPLITS";
|
||||
case RowPartitionType::ROW_LIMITS:
|
||||
return "ROW_LIMITS";
|
||||
case RowPartitionType::ROW_STARTS:
|
||||
return "ROW_STARTS";
|
||||
default:
|
||||
return "UNKNOWN ROW PARTITION TYPE";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<RowPartitionType> GetRowPartitionTypesHelper(
|
||||
const std::vector<std::string>& row_partition_type_strings) {
|
||||
static const auto kStringToType =
|
||||
new std::unordered_map<std::string, RowPartitionType>(
|
||||
{{"FIRST_DIM_SIZE", RowPartitionType::FIRST_DIM_SIZE},
|
||||
{"VALUE_ROWIDS", RowPartitionType::VALUE_ROWIDS},
|
||||
{"ROW_LENGTHS", RowPartitionType::ROW_LENGTHS},
|
||||
{"ROW_SPLITS", RowPartitionType::ROW_SPLITS},
|
||||
{"ROW_LIMITS", RowPartitionType::ROW_LIMITS},
|
||||
{"ROW_STARTS", RowPartitionType::ROW_STARTS}});
|
||||
std::vector<RowPartitionType> result;
|
||||
for (const auto& type_str : row_partition_type_strings) {
|
||||
const auto iter = kStringToType->find(type_str);
|
||||
if (iter == kStringToType->end()) {
|
||||
break;
|
||||
}
|
||||
result.push_back(iter->second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types) {
|
||||
if (row_partition_types.empty()) {
|
||||
return 0;
|
||||
}
|
||||
if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE) {
|
||||
return row_partition_types.size() - 1;
|
||||
}
|
||||
return row_partition_types.size();
|
||||
}
|
||||
} // namespace tensorflow
|
40
tensorflow/core/util/ragged_to_dense_util_common.h
Normal file
40
tensorflow/core/util/ragged_to_dense_util_common.h
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_
|
||||
#define TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorflow {
|
||||
enum class RowPartitionType {
|
||||
FIRST_DIM_SIZE,
|
||||
VALUE_ROWIDS,
|
||||
ROW_LENGTHS,
|
||||
ROW_SPLITS,
|
||||
ROW_LIMITS,
|
||||
ROW_STARTS
|
||||
};
|
||||
|
||||
std::string RowPartitionTypeToString(RowPartitionType row_partition_type);
|
||||
|
||||
std::vector<RowPartitionType> GetRowPartitionTypesHelper(
|
||||
const std::vector<std::string>& row_partition_type_strings);
|
||||
|
||||
int GetRaggedRank(const std::vector<RowPartitionType>& row_partition_types);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_RAGGED_TO_DENSE_UTIL_COMMON_H_
|
Loading…
Reference in New Issue
Block a user