Add a tensorflow::batch_util::CopyContiguousSlices utility function for
slicing out contiguous pieces of tensors along the batch dimension and copying them to another tensor. PiperOrigin-RevId: 313414257 Change-Id: I2530c58ed53ad8e92e5f976f2dd1728296d12185
This commit is contained in:
parent
f809169da0
commit
a67ee929f5
@ -1008,6 +1008,7 @@ tf_cc_tests(
|
||||
srcs = [
|
||||
"allocator_test.cc",
|
||||
"attr_value_util_test.cc",
|
||||
"batch_util_test.cc",
|
||||
"bfloat16_test.cc",
|
||||
"cancellation_test.cc",
|
||||
"common_shape_fns_test.cc",
|
||||
|
61
tensorflow/core/framework/batch_util_test.cc
Normal file
61
tensorflow/core/framework/batch_util_test.cc
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Tensors whose per-slice element counts match (here 1*2 == 2*1) are
// compatible, so the copy should report OK.
TEST(CopyContiguousSlicesTest, CompatibleShape) {
  Tensor src(DT_FLOAT, {7, 1, 2});
  Tensor dst(DT_FLOAT, {9, 2, 1});
  const auto status = batch_util::CopyContiguousSlices(
      src, /*src_offset=*/2, /*dst_offset=*/0, /*num_slices=*/5, &dst);
  ASSERT_EQ(error::OK, status.code());
}
|
||||
|
||||
// Requesting 5 slices starting at offset 7 from a 7-slice source runs past
// the end of `src`, which must fail the precondition check.
TEST(CopyContiguousSlicesTest, SourceOffsetOutOfRange) {
  Tensor src(DT_FLOAT, {7, 1, 2});
  Tensor dst(DT_FLOAT, {9, 2, 1});
  const auto status = batch_util::CopyContiguousSlices(
      src, /*src_offset=*/7, /*dst_offset=*/0, /*num_slices=*/5, &dst);
  ASSERT_EQ(error::FAILED_PRECONDITION, status.code());
}
|
||||
|
||||
// Asking for 8 slices from a source that only has 7 exceeds the source
// range, so the copy must fail the precondition check.
TEST(CopyContiguousSlicesTest, DstOffsetOutOfRange) {
  Tensor src(DT_FLOAT, {7, 1, 2});
  Tensor dst(DT_FLOAT, {9, 2, 1});
  const auto status = batch_util::CopyContiguousSlices(
      src, /*src_offset=*/0, /*dst_offset=*/0, /*num_slices=*/8, &dst);
  ASSERT_EQ(error::FAILED_PRECONDITION, status.code());
}
|
||||
|
||||
// Copies slices [1, 4) of a [5, 2] source into slices [5, 8) of the
// destination, then checks the written region element-by-element.
TEST(CopyContiguousSlicesTest, CheckDstWithExpectedValues) {
  const auto src = test::AsTensor<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
                                         TensorShape({5, 2}));
  Tensor dst(DT_FLOAT, {9, 2, 1});
  const auto status = batch_util::CopyContiguousSlices(
      src, /*src_offset=*/1, /*dst_offset=*/5, /*num_slices=*/3, &dst);
  ASSERT_EQ(error::OK, status.code());
  const auto expected =
      test::AsTensor<float>({2, 3, 4, 5, 6, 7}, TensorShape({3, 2, 1}));
  test::ExpectTensorEqual<float>(expected, dst.Slice(5, 8));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -53,6 +53,8 @@ namespace batch_util {
|
||||
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
|
||||
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
|
||||
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
|
||||
Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
|
||||
int64 dst_offset, int64 num_slices, Tensor* dst);
|
||||
} // namespace batch_util
|
||||
|
||||
/// @ingroup core
|
||||
@ -679,6 +681,9 @@ class Tensor {
|
||||
friend Status batch_util::MaybeMoveSliceToElement(
|
||||
Tensor* parent, Tensor* element,
|
||||
int64 index); // For access to base<T>().
|
||||
friend Status batch_util::CopyContiguousSlices(
|
||||
const Tensor& src, int64 src_offset, int64 dst_offset, int64 num_slices,
|
||||
Tensor* dst); // For access to base<T>().
|
||||
|
||||
bool CanUseDMA() const;
|
||||
|
||||
|
@ -216,6 +216,79 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
|
||||
}
|
||||
}
|
||||
|
||||
// Copies `num_slices` contiguous slices along the 0th (batch) dimension from
// `src`, starting at `src_offset`, into `dst` starting at `dst_offset`.
//
// Requirements checked here (FailedPrecondition on violation):
//   - `src` and `dst` share a dtype;
//   - both have rank >= 1;
//   - their per-slice element counts (product of dims after the first) match;
//   - [src_offset, src_offset + num_slices) is within src's dim 0, and
//     [dst_offset, dst_offset + num_slices) is within dst's dim 0.
// Returns Unimplemented for dtypes outside the handled type lists.
Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
                            int64 dst_offset, int64 num_slices, Tensor* dst) {
  if (src.dtype() != dst->dtype()) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src and dst have different "
        "dtypes. Source dtype: ",
        // Fixed typo: "dstination" -> "destination".
        src.dtype(), " destination dtype: ", dst->dtype(), ".");
  }
  if (src.dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: src has to be a tensor with "
        "rank >= 1. Source shape: ",
        src.shape().DebugString());
  }

  if (dst->dims() < 1) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: dst has to be a tensor "
        "with rank >= 1. Dest shape: ",
        dst->shape().DebugString());
  }

  const int64 src_dim0 = src.dim_size(0);
  const int64 dst_dim0 = dst->dim_size(0);
  // A "chip" is one slice along dim 0, flattened. Shapes are compatible iff
  // the two chip sizes are equal, regardless of the trailing dims' layout.
  int64 src_chip_size = 1;
  int64 dst_chip_size = 1;
  for (int i = 1; i < src.dims(); ++i) {
    src_chip_size *= src.dim_size(i);
  }
  for (int i = 1; i < dst->dims(); ++i) {
    dst_chip_size *= dst->dim_size(i);
  }

  if (src_chip_size != dst_chip_size) {
    return errors::FailedPrecondition(
        // Fixed message: adjacent literals previously concatenated to
        // "shapes arenot compatible" (missing space).
        "CopyContiguousSlices cannot perform copy: source and dst shapes are "
        "not compatible. Source shape: ",
        src.shape().DebugString(), ", dst shape: ", dst->shape().DebugString());
  }

  if (src_chip_size == 0 && dst_chip_size == 0) {
    // Empty slices: nothing to copy, so offsets need not be validated.
    return Status::OK();
  }

  if (src_offset < 0 || src_offset + num_slices > src_dim0 || dst_offset < 0 ||
      dst_offset + num_slices > dst_dim0) {
    return errors::FailedPrecondition(
        "CopyContiguousSlices cannot perform copy: index out of range. "
        "src_offset: ",
        src_offset, ", num_slices: ", num_slices, ", src_dim0: ", src_dim0,
        ", dst_offset: ", dst_offset, ", dst_dim0: ", dst_dim0, ".");
  }

// Since the slices are contiguous in both tensors, the whole region is one
// flat copy of src_chip_size * num_slices elements.
#define HANDLE_TYPE(T)                                                 \
  case DataTypeToEnum<T>::value: {                                     \
    const T* src_p = src.base<T>() + (src_chip_size * src_offset);     \
    T* dst_p = dst->base<T>() + (dst_chip_size * dst_offset);          \
    HandleSliceToElement<T>(src_p, dst_p, src_chip_size * num_slices); \
    return Status::OK();                                               \
  }

  switch (src.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
    TF_CALL_uint32(HANDLE_TYPE);
    TF_CALL_uint64(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented("CopyContiguousSlices unhandled data type: ",
                                   src.dtype());
  }
}
|
||||
|
||||
// Copies the index^th slice of parent (in the 0th dimension) into element.
|
||||
//
|
||||
// NOTE(mrry): The implementation may be able to optimize the copy to a move.
|
||||
|
@ -32,6 +32,17 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
|
||||
// Copies the index^th slice of parent (in the 0th dimension) into element.
|
||||
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
|
||||
|
||||
// Copies 'num_slices' contiguous slices from 'src' tensor starting from index
|
||||
// 'src_offset' into target tensor 'dst', and places them into slices
|
||||
// starting from 'dst_offset'.
|
||||
//
|
||||
// This function requires 'src' and 'dst' to have compatible shapes. That is it
|
||||
// requires cum_prod(src.shape[1:]) == cum_prod(dst->shape[1:]). For example if
|
||||
// source is of shape [x, 2, 1] and dst is a tensor of shape [y, 1, 2], this
|
||||
// function can still proceed successfully.
|
||||
Status CopyContiguousSlices(const Tensor& src, int64 src_offset,
|
||||
int64 dst_offset, int64 num_slices, Tensor* dst);
|
||||
|
||||
// Copies the index^th slice of parent (in the 0th dimension) into element.
|
||||
//
|
||||
// NOTE(mrry): The implementation may be able to optimize the copy to a move.
|
||||
|
Loading…
Reference in New Issue
Block a user