[tf.data] Further optimizations for SerializeManySparseOp.
1. Use DMAHelper to access the tensor base pointers without additional alignment and type checks, and use these pointers to access the elements of the tensors directly. 2. Add a special case for rank == 2 (which is the common case when batching Example protos), to avoid a length-1 loop per element. 3. Use `memcpy` where possible (and otherwise, `std::copy_n`) instead of Eigen assignment for the group values. PiperOrigin-RevId: 291176480 Change-Id: I331213c0ac1caadf620c87833759b8a6550f1752
This commit is contained in:
parent
fcad65986c
commit
649a04fbe9
tensorflow/core/kernels
@ -5269,6 +5269,7 @@ tf_kernel_library(
|
||||
prefix = "serialize_sparse_op",
|
||||
deps = SPARSE_DEPS + [
|
||||
":reshape_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -215,6 +216,36 @@ struct SerializeGroups<T, tstring> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void CopyValues(const T* src, T* dest, int64 num_values) {
|
||||
static_assert(is_simple_type<T>::value, "Memcpy requires a simple type.");
|
||||
memcpy(dest, src, num_values * sizeof(T));
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyValues<tstring>(const tstring* src, tstring* dest, int64 num_values) {
|
||||
std::copy_n(src, num_values, dest);
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyValues<Variant>(const Variant* src, Variant* dest, int64 num_values) {
|
||||
std::copy_n(src, num_values, dest);
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyValues<ResourceHandle>(const ResourceHandle* src, ResourceHandle* dest,
|
||||
int64 num_values) {
|
||||
std::copy_n(src, num_values, dest);
|
||||
}
|
||||
|
||||
template <>
|
||||
void CopyValues<Eigen::half>(const Eigen::half* src, Eigen::half* dest,
|
||||
int64 num_values) {
|
||||
return CopyValues(reinterpret_cast<const char*>(src),
|
||||
reinterpret_cast<char*>(dest),
|
||||
num_values * sizeof(Eigen::half));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct SerializeGroups<T, Variant> {
|
||||
Status operator()(sparse::GroupIterable* minibatch,
|
||||
@ -263,16 +294,30 @@ struct SerializeGroups<T, Variant> {
|
||||
Tensor& output_values = serialized_sparse_t(b, 1).emplace<Tensor>(
|
||||
T_type, TensorShape({num_entries}));
|
||||
|
||||
auto output_indices_t = output_indices.matrix<int64>();
|
||||
auto output_values_t = output_values.vec<T>();
|
||||
int64* output_indices_ptr =
|
||||
static_cast<int64*>(DMAHelper::base(&output_indices));
|
||||
const int64* indices_ptr = indices.data();
|
||||
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
for (int d = 1; d < rank; ++d) {
|
||||
output_indices_t(i, d - 1) = indices(i, d);
|
||||
T* output_values_ptr = static_cast<T*>(DMAHelper::base(&output_values));
|
||||
const T* values_ptr = values.data();
|
||||
|
||||
// TODO(mrry): Consider adding a template-based specialization for higher
|
||||
// ranks.
|
||||
if (rank == 2) {
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
output_indices_ptr[i] = indices_ptr[(2 * i) + 1];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < num_entries; ++i) {
|
||||
// Skip the first index in each row.
|
||||
++indices_ptr;
|
||||
for (int d = 1; d < rank; ++d) {
|
||||
*output_indices_ptr++ = *indices_ptr++;
|
||||
}
|
||||
}
|
||||
output_values_t(i) = values(i);
|
||||
}
|
||||
|
||||
CopyValues(values_ptr, output_values_ptr, num_entries);
|
||||
serialized_sparse_t(b, 2).emplace<Tensor>(output_shape);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user