Remove unused checkpointing code from TensorSliceSet.
PiperOrigin-RevId: 261022598
This commit is contained in:
parent
9bdc9dbf52
commit
a1b67023fe
tensorflow/core/util
@ -30,8 +30,7 @@ TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
|
||||
|
||||
TensorSliceSet::~TensorSliceSet() {}
|
||||
|
||||
Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
|
||||
const float* data) {
|
||||
Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) {
|
||||
TensorShape result_shape;
|
||||
TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
|
||||
string str = slice.DebugString();
|
||||
@ -53,69 +52,11 @@ Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
|
||||
slices_hull_.UpdateToCover(slice);
|
||||
}
|
||||
|
||||
TensorSliceSet::SliceInfo info = {slice, tag, data,
|
||||
result_shape.num_elements()};
|
||||
TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()};
|
||||
slices_.insert(std::make_pair(str, info));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO(yangke): merge Query() with QueryMeta()
|
||||
bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
|
||||
Status s;
|
||||
string str = slice.DebugString();
|
||||
// First we check if there is an exactly match (this is the dominant case).
|
||||
const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
|
||||
if (info) {
|
||||
if (data) {
|
||||
std::copy_n(info->data, info->num_floats, data);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
// We didn't find any exact match but there is still a possibility that
|
||||
// multiple existing slices can be patched together to output the slice.
|
||||
// We figure this out by computing the intersection of each of the existing
|
||||
// slices with the query slice, and check if the union of all these
|
||||
// intersections cover the entire slice. We rely on the fact that the
|
||||
// existing slices don't have any intersection among themselves.
|
||||
TensorShape target_shape;
|
||||
Status s;
|
||||
s = slice.SliceTensorShape(shape_, &target_shape);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << s;
|
||||
return false;
|
||||
}
|
||||
int64 total_size = target_shape.num_elements();
|
||||
|
||||
int64 overlap_size = 0;
|
||||
TensorSlice intersection;
|
||||
TensorShape inter_shape;
|
||||
for (const auto& x : slices_) {
|
||||
if (slice.Intersect(x.second.slice, &intersection)) {
|
||||
s = intersection.SliceTensorShape(shape_, &inter_shape);
|
||||
if (!s.ok()) {
|
||||
LOG(WARNING) << s;
|
||||
return false;
|
||||
}
|
||||
overlap_size += inter_shape.num_elements();
|
||||
}
|
||||
}
|
||||
if (total_size == overlap_size) {
|
||||
// We have it!
|
||||
// Now we need to copy the data to "data"
|
||||
if (data) {
|
||||
for (const auto& x : slices_) {
|
||||
CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
|
||||
x.second.data, data);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
// We don't have all the data for the asked tensor slice
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorSliceSet::QueryMeta(
|
||||
const TensorSlice& slice,
|
||||
std::vector<std::pair<TensorSlice, string>>* results) const {
|
||||
@ -194,7 +135,7 @@ Status RegisterTensorSlice(
|
||||
}
|
||||
}
|
||||
// Register the tensor slices without the actual data.
|
||||
return tss->Register(slice, tag, nullptr);
|
||||
return tss->Register(slice, tag);
|
||||
}
|
||||
|
||||
} // namespace checkpoint
|
||||
|
@ -16,11 +16,8 @@ limitations under the License.
|
||||
// A class to manage slices of a tensor. You can "register" set of slices for a
|
||||
// tensor and then "query" if we have data for a given slice.
|
||||
|
||||
// TODO(yangke): consider moving it to a more private place so that we don't
|
||||
// need to expose the API.
|
||||
|
||||
#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
|
||||
#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
|
||||
#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
|
||||
#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
|
||||
|
||||
#include <string> // for string
|
||||
#include <unordered_map>
|
||||
@ -49,18 +46,7 @@ class TensorSliceSet {
|
||||
// associated with the slice (in one application it denotes the name of the
|
||||
// file that contains the slice); the "data" points to the data of the tensor
|
||||
// slice (it can be a nullptr).
|
||||
// We don't take the ownership of "data" and the caller needs to make sure
|
||||
// the data is always available during the life time of the tensor slice set
|
||||
// if it is not nullptr.
|
||||
Status Register(const TensorSlice& slice, const string& tag,
|
||||
const float* data);
|
||||
|
||||
// Query about a new slice: checks if we have data for "slice" and if we have
|
||||
// the data and "data" is not nullptr, fill "data" with the slice data. The
|
||||
// caller needs to make sure "data" point to a large enough buffer.
|
||||
// TODO(yangke): avoid unnecessary copying by using a core::RefCounted
|
||||
// pointer.
|
||||
bool Query(const TensorSlice& slice, float* data) const;
|
||||
Status Register(const TensorSlice& slice, const string& tag);
|
||||
|
||||
// Alternative way of querying about a new slice: instead of copying the
|
||||
// data, it returns a list of meta data about the stored slices that will
|
||||
@ -72,7 +58,6 @@ class TensorSliceSet {
|
||||
struct SliceInfo {
|
||||
TensorSlice slice;
|
||||
const string tag;
|
||||
const float* data;
|
||||
int64 num_floats;
|
||||
};
|
||||
|
||||
@ -105,4 +90,4 @@ Status RegisterTensorSlice(
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
|
||||
#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
|
||||
|
@ -36,107 +36,6 @@ namespace {
|
||||
//
|
||||
// We assume this is a row-major matrix.
|
||||
//
|
||||
// We store the tensor in a couple of slices and verify that we can recover all
|
||||
// of them.
|
||||
TEST(TensorSliceSetTest, QueryTwoD) {
|
||||
TensorShape shape({4, 5});
|
||||
|
||||
TensorSliceSet tss(shape, DT_FLOAT);
|
||||
// We store a few slices.
|
||||
|
||||
// Slice #1 is the top two rows:
|
||||
// 0 1 2 3 4
|
||||
// 5 6 7 8 9
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
|
||||
TF_CHECK_OK(tss.Register(slice_1, "", src_1));
|
||||
|
||||
// Slice #2 is the bottom left corner
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
// 10 11 12 . .
|
||||
// 15 16 17 . .
|
||||
const float src_2[] = {10, 11, 12, 15, 16, 17};
|
||||
TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
|
||||
TF_CHECK_OK(tss.Register(slice_2, "", src_2));
|
||||
|
||||
// Slice #3 is the bottom right corner
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
// . . . 18 19
|
||||
const float src_3[] = {18, 19};
|
||||
TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
|
||||
TF_CHECK_OK(tss.Register(slice_3, "", src_3));
|
||||
|
||||
// Notice that we leave a hole in the tensor
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
// . . . (13) (14)
|
||||
// . . . . .
|
||||
|
||||
// Now we query some of the slices
|
||||
|
||||
// Slice #1 is an exact match
|
||||
// 0 1 2 3 4
|
||||
// 5 6 7 8 9
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
{
|
||||
TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
|
||||
float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
float results[10];
|
||||
EXPECT_TRUE(tss.Query(s, results));
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
EXPECT_EQ(expected[i], results[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Slice #2 is a subset match
|
||||
// . . . . .
|
||||
// 5 6 7 8 9
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
{
|
||||
TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
|
||||
float expected[] = {5, 6, 7, 8, 9};
|
||||
float results[5];
|
||||
EXPECT_TRUE(tss.Query(s, results));
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
EXPECT_EQ(expected[i], results[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Slice #3 is a more complicated match: it needs the combination of a couple
|
||||
// of slices
|
||||
// . . . . .
|
||||
// 5 6 7 . .
|
||||
// 10 11 12 . .
|
||||
// . . . . .
|
||||
{
|
||||
TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
|
||||
float expected[] = {5, 6, 7, 10, 11, 12};
|
||||
float results[6];
|
||||
EXPECT_TRUE(tss.Query(s, results));
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
EXPECT_EQ(expected[i], results[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Slice #4 includes the hole and so there is no match
|
||||
// . . . . .
|
||||
// . . 7 8 9
|
||||
// . . 12 13 14
|
||||
// . . . . .
|
||||
{
|
||||
TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
|
||||
float results[6];
|
||||
EXPECT_FALSE(tss.Query(s, results));
|
||||
}
|
||||
}
|
||||
|
||||
// Testing the meta version of the tensor slice set.
|
||||
TEST(TensorSliceSetTest, QueryMetaTwoD) {
|
||||
TensorShape shape({4, 5});
|
||||
@ -150,7 +49,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
|
||||
// . . . . .
|
||||
// . . . . .
|
||||
TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
|
||||
TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
|
||||
TF_CHECK_OK(tss.Register(slice_1, "slice_1"));
|
||||
|
||||
// Slice #2 is the bottom left corner
|
||||
// . . . . .
|
||||
@ -158,7 +57,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
|
||||
// 10 11 12 . .
|
||||
// 15 16 17 . .
|
||||
TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
|
||||
TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
|
||||
TF_CHECK_OK(tss.Register(slice_2, "slice_2"));
|
||||
|
||||
// Slice #3 is the bottom right corner
|
||||
// . . . . .
|
||||
@ -166,7 +65,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
|
||||
// . . . . .
|
||||
// . . . 18 19
|
||||
TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
|
||||
TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
|
||||
TF_CHECK_OK(tss.Register(slice_3, "slice_3"));
|
||||
|
||||
// Notice that we leave a hole in the tensor
|
||||
// . . . . .
|
||||
@ -250,7 +149,7 @@ static void BM_RegisterOneByOne(int parts) {
|
||||
TensorSliceSet slice_set(shape, DT_INT32);
|
||||
for (int i = 0; i < parts; ++i) {
|
||||
TensorSlice part({{i, 1}, {0, -1}});
|
||||
TF_CHECK_OK(slice_set.Register(part, part.DebugString(), nullptr));
|
||||
TF_CHECK_OK(slice_set.Register(part, part.DebugString()));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user