Remove unused checkpointing code from TensorSliceSet.

PiperOrigin-RevId: 261022598
This commit is contained in:
Derek Murray 2019-07-31 16:33:12 -07:00 committed by TensorFlower Gardener
parent 9bdc9dbf52
commit a1b67023fe
3 changed files with 11 additions and 186 deletions

View File

@ -30,8 +30,7 @@ TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
TensorSliceSet::~TensorSliceSet() {}
Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
const float* data) {
Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) {
TensorShape result_shape;
TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
string str = slice.DebugString();
@ -53,69 +52,11 @@ Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
slices_hull_.UpdateToCover(slice);
}
TensorSliceSet::SliceInfo info = {slice, tag, data,
result_shape.num_elements()};
TensorSliceSet::SliceInfo info = {slice, tag, result_shape.num_elements()};
slices_.insert(std::make_pair(str, info));
return Status::OK();
}
// TODO(yangke): merge Query() with QueryMeta()
bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
Status s;
string str = slice.DebugString();
// First we check if there is an exactly match (this is the dominant case).
const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
if (info) {
if (data) {
std::copy_n(info->data, info->num_floats, data);
}
return true;
} else {
// We didn't find any exact match but there is still a possibility that
// multiple existing slices can be patched together to output the slice.
// We figure this out by computing the intersection of each of the existing
// slices with the query slice, and check if the union of all these
// intersections cover the entire slice. We rely on the fact that the
// existing slices don't have any intersection among themselves.
TensorShape target_shape;
Status s;
s = slice.SliceTensorShape(shape_, &target_shape);
if (!s.ok()) {
LOG(WARNING) << s;
return false;
}
int64 total_size = target_shape.num_elements();
int64 overlap_size = 0;
TensorSlice intersection;
TensorShape inter_shape;
for (const auto& x : slices_) {
if (slice.Intersect(x.second.slice, &intersection)) {
s = intersection.SliceTensorShape(shape_, &inter_shape);
if (!s.ok()) {
LOG(WARNING) << s;
return false;
}
overlap_size += inter_shape.num_elements();
}
}
if (total_size == overlap_size) {
// We have it!
// Now we need to copy the data to "data"
if (data) {
for (const auto& x : slices_) {
CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
x.second.data, data);
}
}
return true;
} else {
// We don't have all the data for the asked tensor slice
return false;
}
}
}
bool TensorSliceSet::QueryMeta(
const TensorSlice& slice,
std::vector<std::pair<TensorSlice, string>>* results) const {
@ -194,7 +135,7 @@ Status RegisterTensorSlice(
}
}
// Register the tensor slices without the actual data.
return tss->Register(slice, tag, nullptr);
return tss->Register(slice, tag);
}
} // namespace checkpoint

View File

@ -16,11 +16,8 @@ limitations under the License.
// A class to manage slices of a tensor. You can "register" set of slices for a
// tensor and then "query" if we have data for a given slice.
// TODO(yangke): consider moving it to a more private place so that we don't
// need to expose the API.
#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_
#include <string> // for string
#include <unordered_map>
@ -49,18 +46,7 @@ class TensorSliceSet {
// associated with the slice (in one application it denotes the name of the
// file that contains the slice); the "data" points to the data of the tensor
// slice (it can be a nullptr).
// We don't take the ownership of "data" and the caller needs to make sure
// the data is always available during the life time of the tensor slice set
// if it is not nullptr.
Status Register(const TensorSlice& slice, const string& tag,
const float* data);
// Query about a new slice: checks if we have data for "slice" and if we have
// the data and "data" is not nullptr, fill "data" with the slice data. The
// caller needs to make sure "data" point to a large enough buffer.
// TODO(yangke): avoid unnecessary copying by using a core::RefCounted
// pointer.
bool Query(const TensorSlice& slice, float* data) const;
Status Register(const TensorSlice& slice, const string& tag);
// Alternative way of querying about a new slice: instead of copying the
// data, it returns a list of meta data about the stored slices that will
@ -72,7 +58,6 @@ class TensorSliceSet {
struct SliceInfo {
TensorSlice slice;
const string tag;
const float* data;
int64 num_floats;
};
@ -105,4 +90,4 @@ Status RegisterTensorSlice(
} // namespace tensorflow
#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_SET_H_

View File

@ -36,107 +36,6 @@ namespace {
//
// We assume this is a row-major matrix.
//
// We store the tensor in a couple of slices and verify that we can recover all
// of them.
TEST(TensorSliceSetTest, QueryTwoD) {
TensorShape shape({4, 5});
TensorSliceSet tss(shape, DT_FLOAT);
// We store a few slices.
// Slice #1 is the top two rows:
// 0 1 2 3 4
// 5 6 7 8 9
// . . . . .
// . . . . .
const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(tss.Register(slice_1, "", src_1));
// Slice #2 is the bottom left corner
// . . . . .
// . . . . .
// 10 11 12 . .
// 15 16 17 . .
const float src_2[] = {10, 11, 12, 15, 16, 17};
TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
TF_CHECK_OK(tss.Register(slice_2, "", src_2));
// Slice #3 is the bottom right corner
// . . . . .
// . . . . .
// . . . . .
// . . . 18 19
const float src_3[] = {18, 19};
TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
TF_CHECK_OK(tss.Register(slice_3, "", src_3));
// Notice that we leave a hole in the tensor
// . . . . .
// . . . . .
// . . . (13) (14)
// . . . . .
// Now we query some of the slices
// Slice #1 is an exact match
// 0 1 2 3 4
// 5 6 7 8 9
// . . . . .
// . . . . .
{
TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
float results[10];
EXPECT_TRUE(tss.Query(s, results));
for (int i = 0; i < 10; ++i) {
EXPECT_EQ(expected[i], results[i]);
}
}
// Slice #2 is a subset match
// . . . . .
// 5 6 7 8 9
// . . . . .
// . . . . .
{
TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
float expected[] = {5, 6, 7, 8, 9};
float results[5];
EXPECT_TRUE(tss.Query(s, results));
for (int i = 0; i < 5; ++i) {
EXPECT_EQ(expected[i], results[i]);
}
}
// Slice #3 is a more complicated match: it needs the combination of a couple
// of slices
// . . . . .
// 5 6 7 . .
// 10 11 12 . .
// . . . . .
{
TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
float expected[] = {5, 6, 7, 10, 11, 12};
float results[6];
EXPECT_TRUE(tss.Query(s, results));
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(expected[i], results[i]);
}
}
// Slice #4 includes the hole and so there is no match
// . . . . .
// . . 7 8 9
// . . 12 13 14
// . . . . .
{
TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
float results[6];
EXPECT_FALSE(tss.Query(s, results));
}
}
// Testing the meta version of the tensor slice set.
TEST(TensorSliceSetTest, QueryMetaTwoD) {
TensorShape shape({4, 5});
@ -150,7 +49,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
// . . . . .
// . . . . .
TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
TF_CHECK_OK(tss.Register(slice_1, "slice_1"));
// Slice #2 is the bottom left corner
// . . . . .
@ -158,7 +57,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
// 10 11 12 . .
// 15 16 17 . .
TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
TF_CHECK_OK(tss.Register(slice_2, "slice_2"));
// Slice #3 is the bottom right corner
// . . . . .
@ -166,7 +65,7 @@ TEST(TensorSliceSetTest, QueryMetaTwoD) {
// . . . . .
// . . . 18 19
TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
TF_CHECK_OK(tss.Register(slice_3, "slice_3"));
// Notice that we leave a hole in the tensor
// . . . . .
@ -250,7 +149,7 @@ static void BM_RegisterOneByOne(int parts) {
TensorSliceSet slice_set(shape, DT_INT32);
for (int i = 0; i < parts; ++i) {
TensorSlice part({{i, 1}, {0, -1}});
TF_CHECK_OK(slice_set.Register(part, part.DebugString(), nullptr));
TF_CHECK_OK(slice_set.Register(part, part.DebugString()));
}
}