Added a format for saving an inference graph that can be memmapped and an utility to convert a freezed graph into this format.

Change: 120128412
This commit is contained in:
A. Unique TensorFlower 2016-04-18 08:08:23 -08:00 committed by TensorFlower Gardener
parent 517d3af445
commit 3c280f6fa0
20 changed files with 1219 additions and 27 deletions

View File

@ -7,6 +7,47 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
# Convertor of a frozen graph definition into the memmapped format.
cc_library(
name = "convert_graphdef_memmapped_format_lib",
srcs = ["convert_graphdef_memmapped_format_lib.cc"],
hdrs = ["convert_graphdef_memmapped_format_lib.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core/kernels:immutable_constant_op",
],
)
cc_binary(
name = "convert_graphdef_memmapped_format",
srcs = ["convert_graphdef_memmapped_format.cc"],
deps = [
":convert_graphdef_memmapped_format_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_test(
name = "convert_graphdef_memmapped_format_test",
srcs = ["convert_graphdef_memmapped_format_test.cc"],
deps = [
":convert_graphdef_memmapped_format_lib",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_binary(
name = "inspect_checkpoint",
srcs = ["inspect_checkpoint.cc"],

View File

@ -0,0 +1,88 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Utility that converts a "frozen" inference graph (output from the
// freeze_graph utility) into a format in which large Const ops are converted to
// ImmutableConst ops which are memmapped when the graph is executed by
// TensorFlow.
//
// tensorflow/contrib/util/convert_graphdef_memmapped_format
// --in_graph=frozen.model --out_graph=memmapped.mmodel
//
// Parameters:
// in_graph - name of a file with a frozen GraphDef proto in binary format
// out_graph - name of the output file, where the graph in memmapped format will
// be saved.
// min_conversion_size_bytes - tensors with fewer than this many bytes of data
// will not be converted to ImmutableConst format, and kept in the graph.
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
string in_graph = "";
string out_graph = "";
int min_conversion_tensor_size = 10000;
const bool parse_result = ParseFlags(
&argc, argv,
{// input graph
Flag("in_graph", &in_graph),
// output graph
Flag("out_graph", &out_graph),
// constants with tensors that have less than this number elements won't
// be converted into ImmutableConst (be memmapped).
Flag("min_conversion_tensor_size", &min_conversion_tensor_size)});
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
if (!parse_result) {
LOG(ERROR) << "Error parsing command-line flags.";
return -1;
}
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1];
return -1;
}
if (in_graph.empty()) {
LOG(ERROR) << "in_graph graph can't be empty";
return -1;
}
if (out_graph.empty()) {
LOG(ERROR) << "out_graph graph can't be empty";
return -1;
}
if (min_conversion_tensor_size <= 0) {
LOG(ERROR) << "min_conversion_tensor_size must be > 0";
return -1;
}
const auto result = ConvertConstantsToImmutable(in_graph, out_graph,
min_conversion_tensor_size);
if (!result.ok()) {
LOG(ERROR) << "Conversion failed " << result.error_message();
return -1;
}
return 0;
}
} // namespace
} // namespace tensorflow
int main(int argc, char* argv[]) {
return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
}

View File

@ -0,0 +1,156 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/immutable_constant_op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/memmapped_file_system_writer.h"
namespace tensorflow {
namespace {
class NodeConverter {
public:
// Converts one node. In-place updates node_def, writes the tensor in
// memmapped
// format, using writer. If the conversion has been done, convert_counter is
// increased.
Status ConvertConstantsToImmutable(NodeDef* node_def,
MemmappedFileSystemWriter* writer,
int* convert_counter,
int min_conversion_size_bytes) {
// Check the size.
const AttrValue& value = node_def->attr().at("value");
const TensorProto& tensor_proto = value.tensor();
// Create copies of tensor datatype and shape, to put into the operator
// after
// the tensor is destroyed.
const DataType tensor_data_type = tensor_proto.dtype();
const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
// Create Tensor from value and write it in memmapped format.
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
if (parsed.TotalBytes() < min_conversion_size_bytes) {
return Status::OK();
}
const string memmapped_region_name =
MemmappedFileSystem::kMemmappedPackagePrefix +
ConvertVariableNameToUniqueRegionName(node_def->name());
TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name));
node_def->set_op("ImmutableConst");
// Erase all attributes and leave only attributes that can be understood by
// ImmutableConst.
auto* mutable_attr = node_def->mutable_attr();
mutable_attr->clear();
{
AttrValue attr_value;
attr_value.set_type(tensor_data_type);
mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value});
}
{
AttrValue attr_value;
*(attr_value.mutable_shape()) = tensor_shape;
mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value});
}
{
AttrValue attr_value;
attr_value.set_s(memmapped_region_name);
mutable_attr->insert(
{ImmutableConstantOp::kMemoryRegionNameAttr, attr_value});
}
++*convert_counter;
return Status::OK();
}
private:
string ConvertVariableNameToUniqueRegionName(const string& variable_name) {
string region_name = SanitizeVariableName(variable_name);
while (!used_names_.insert(region_name).second) {
region_name += '_';
}
return region_name;
}
static string SanitizeVariableName(const string& variable_name) {
string result;
for (char c : variable_name) {
if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || c == '_' || c == '.') {
result += c;
} else {
result += '_';
}
}
return result;
}
std::unordered_set<string> used_names_;
};
} // namespace
// Loads the graph, replaces operators, and writes it out.
Status ConvertConstantsToImmutable(const string& in_graph_filename,
const string& out_graph_filename,
int min_conversion_size_bytes) {
Env* default_env = Env::Default();
GraphDef graph_def;
const auto load_graph_status =
ReadBinaryProto(default_env, in_graph_filename, &graph_def);
if (!load_graph_status.ok()) {
return tensorflow::errors::NotFound("Failed to load graph at '",
in_graph_filename, "' : ",
load_graph_status.error_message());
}
NodeConverter node_converter;
// Create output writer.
MemmappedFileSystemWriter writer;
TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename));
// Iterate over graph nodes, looking for Const and replacing it with
// ImmutableConst.
int convert_counter = 0;
for (int i = 0; i < graph_def.node_size(); ++i) {
const NodeDef& node = graph_def.node(i);
if (node.op() == "Const") {
// Try to convert to ImmutableConst
TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable(
graph_def.mutable_node(i), &writer, &convert_counter,
min_conversion_size_bytes));
}
}
TF_RETURN_IF_ERROR(writer.SaveProtobuf(
graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef));
TF_RETURN_IF_ERROR(writer.FlushAndClose());
LOG(INFO) << "Converted " << convert_counter << " nodes";
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,34 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Converts a "frozen" inference graph (output from the freeze_graph utility)
// into a format in which large Const ops are converted to ImmutableConst ops
// which are memmapped when the graph is executed by TensorFlow.
Status ConvertConstantsToImmutable(const string& in_graph_filename,
const string& out_graph_filename,
int min_conversion_size_bytes);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_

View File

@ -0,0 +1,84 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/memmapped_file_system.h"
namespace tensorflow {
namespace {
TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
const string dir = testing::TmpDir();
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
// Create a simple graph and write it to filename_pb.
constexpr int kTensorWidth = 4000;
constexpr int kTensorHeight = 100;
const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth});
Tensor test_tensor1(DT_FLOAT, kTestTensorShape);
test::FillFn<float>(&test_tensor1, [](int) -> float { return 2.0; });
Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT);
test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; });
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* node1 = ops::Const(test_tensor1, b.opts());
Node* node2 = ops::Const(test_tensor2, b.opts());
const string result_name = ops::MatMul(node1, node2, b.opts())->name();
GraphDef graph_def;
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
string graph_def_serialized;
graph_def.SerializeToString(&graph_def_serialized);
TF_ASSERT_OK(
WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));
const string filename_mmap = io::JoinPath(dir, "graphdef.mmap");
TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000));
// Create and initialize MemmappedEnv from the converted file.
MemmappedEnv memmapped_env(Env::Default());
TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));
// Load the graph and run calculations.
SessionOptions session_options;
session_options.env = &memmapped_env;
std::unique_ptr<Session> session(NewSession(session_options));
ASSERT_TRUE(session != nullptr) << "Failed to create session";
GraphDef loaded_graph_def;
TF_ASSERT_OK(ReadBinaryProto(
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&loaded_graph_def));
TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs));
ASSERT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f * kTensorHeight);
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f * kTensorHeight);
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
}
} // namespace
} // namespace tensorflow

View File

@ -261,6 +261,8 @@ tf_cuda_library(
"util/device_name_utils.h",
"util/events_writer.h",
"util/guarded_philox_random.h",
"util/memmapped_file_system.h",
"util/memmapped_file_system_writer.h",
"util/mirror_pad_mode.h",
"util/padding.h",
"util/port.h",

View File

@ -43,6 +43,8 @@ string AllocatorStats::DebugString() const {
this->num_allocs, this->max_alloc_size);
}
constexpr size_t Allocator::kAllocatorAlignment;
Allocator::~Allocator() {}
// If true, cpu allocator collects more stats.

View File

@ -66,6 +66,9 @@ struct AllocatorStats {
// device memory.
class Allocator {
public:
// Align to 32 byte boundary.
static constexpr size_t kAllocatorAlignment = 32;
virtual ~Allocator();
// Return a string identifying this allocator
@ -112,8 +115,8 @@ class Allocator {
return NULL;
}
void* p = AllocateRaw(32 /* align to 32 byte boundary */,
sizeof(T) * num_elements, allocation_attr);
void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements,
allocation_attr);
T* typed_p = reinterpret_cast<T*>(p);
if (typed_p) RunCtor<T>(typed_p, num_elements);
return typed_p;
@ -192,11 +195,10 @@ class Allocator {
// without running their default ctors and dtors.
template <typename T>
struct is_simple {
static const bool value = std::is_trivial<T>::value ||
std::is_same<T, Eigen::half>::value ||
std::is_same<T, complex64>::value ||
std::is_same<T, complex128>::value ||
is_quantized<T>::value;
static constexpr bool value =
std::is_trivial<T>::value || std::is_same<T, Eigen::half>::value ||
std::is_same<T, complex64>::value ||
std::is_same<T, complex128>::value || is_quantized<T>::value;
};
// Fills in 'stats' with statistics collected by this allocator.

View File

@ -20,11 +20,12 @@ namespace tensorflow {
ImmutableConstantOp::ImmutableConstantOp(OpKernelConstruction* context)
: OpKernel(context) {
::tensorflow::DataType dtype;
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));
OP_REQUIRES_OK(context, context->GetAttr(kDTypeAttr, &dtype));
::tensorflow::TensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape));
OP_REQUIRES_OK(context, context->GetAttr(kShapeAttr, &shape));
string region_name;
OP_REQUIRES_OK(context, context->GetAttr("memory_region_name", &region_name));
OP_REQUIRES_OK(context,
context->GetAttr(kMemoryRegionNameAttr, &region_name));
OP_REQUIRES_OK(context,
allocator_.InitWithMemoryRegion(region_name, context->env()));
tensor_ = ::tensorflow::Tensor(&allocator_, dtype, shape);
@ -90,6 +91,10 @@ void ImmutableConstantOp::ReadOnlyMemoryRegionAllocator::DeallocateRaw(
}
}
constexpr char ImmutableConstantOp::kDTypeAttr[];
constexpr char ImmutableConstantOp::kShapeAttr[];
constexpr char ImmutableConstantOp::kMemoryRegionNameAttr[];
REGISTER_KERNEL_BUILDER(Name("ImmutableConst").Device(DEVICE_CPU),
ImmutableConstantOp);
} // namespace tensorflow

View File

@ -32,6 +32,11 @@ class ImmutableConstantOp : public OpKernel {
bool IsExpensive() override { return false; }
~ImmutableConstantOp() override;
// Names of attributes that are used by this op
static constexpr char kDTypeAttr[] = "dtype";
static constexpr char kShapeAttr[] = "shape";
static constexpr char kMemoryRegionNameAttr[] = "memory_region_name";
private:
class ReadOnlyMemoryRegionAllocator : public ::tensorflow::Allocator {
public:

View File

@ -23,7 +23,7 @@ namespace tensorflow {
class FileSystemRegistryImpl : public FileSystemRegistry {
public:
void Register(const string& scheme, Factory factory) override;
Status Register(const string& scheme, Factory factory) override;
FileSystem* Lookup(const string& scheme) override;
Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes) override;
@ -33,13 +33,15 @@ class FileSystemRegistryImpl : public FileSystemRegistry {
GUARDED_BY(mu_);
};
void FileSystemRegistryImpl::Register(const string& scheme,
FileSystemRegistry::Factory factory) {
Status FileSystemRegistryImpl::Register(const string& scheme,
FileSystemRegistry::Factory factory) {
mutex_lock lock(mu_);
QCHECK(
registry_.emplace(string(scheme), std::unique_ptr<FileSystem>(factory()))
.second)
<< "File factory for " << scheme << " already registered";
if (!registry_.emplace(string(scheme), std::unique_ptr<FileSystem>(factory()))
.second) {
return errors::AlreadyExists("File factory for ", scheme,
" already registered");
}
return Status::OK();
}
FileSystem* FileSystemRegistryImpl::Lookup(const string& scheme) {
@ -77,9 +79,9 @@ Status Env::GetRegisteredFileSystemSchemes(std::vector<string>* schemes) {
return file_system_registry_->GetRegisteredFileSystemSchemes(schemes);
}
void Env::RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory) {
file_system_registry_->Register(scheme, factory);
Status Env::RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory) {
return file_system_registry_->Register(scheme, factory);
}
Status Env::NewRandomAccessFile(const string& fname,

View File

@ -68,8 +68,8 @@ class Env {
virtual Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes);
// \brief Register a file system for a scheme.
virtual void RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory);
virtual Status RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory);
/// \brief Creates a brand new random access read-only file with the
/// specified name.
@ -236,9 +236,9 @@ class EnvWrapper : public Env {
return target_->GetRegisteredFileSystemSchemes(schemes);
}
void RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory) override {
target_->RegisterFileSystem(scheme, factory);
Status RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory) override {
return target_->RegisterFileSystem(scheme, factory);
}
uint64 NowMicros() override { return target_->NowMicros(); }

View File

@ -34,7 +34,7 @@ class RandomAccessFile;
class ReadOnlyMemoryRegion;
class WritableFile;
/// An generic interface for accessing a file system.
/// A generic interface for accessing a file system.
class FileSystem {
public:
FileSystem() {}
@ -202,7 +202,7 @@ class FileSystemRegistry {
typedef std::function<FileSystem*()> Factory;
virtual ~FileSystemRegistry();
virtual void Register(const string& scheme, Factory factory) = 0;
virtual Status Register(const string& scheme, Factory factory) = 0;
virtual FileSystem* Lookup(const string& scheme) = 0;
virtual Status GetRegisteredFileSystemSchemes(
std::vector<string>* schemes) = 0;

View File

@ -0,0 +1,281 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/memmapped_file_system.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/memmapped_file_system.pb.h"
namespace tensorflow {
namespace {
uint64 DecodeUint64LittleEndian(const uint8* buffer) {
uint64 result = 0;
for (int i = 0; i < static_cast<int>(sizeof(uint64)); ++i) {
result |= buffer[i] << (8 * i);
}
return result;
}
} // namespace
namespace {
class ReadOnlyMemoryRegionFromMemmapped : public ReadOnlyMemoryRegion {
public:
ReadOnlyMemoryRegionFromMemmapped(const void* data, uint64 length)
: data_(data), length_(length) {}
~ReadOnlyMemoryRegionFromMemmapped() override = default;
const void* data() override { return data_; }
uint64 length() override { return length_; }
private:
const void* const data_;
const uint64 length_;
// intentionally copyable
};
class RandomAccessFileFromMemmapped : public RandomAccessFile {
public:
RandomAccessFileFromMemmapped(const void* data, uint64 length)
: data_(data), length_(length) {}
~RandomAccessFileFromMemmapped() override = default;
Status Read(uint64 offset, size_t to_read, StringPiece* result,
char* scratch) const override {
if (offset >= length_) {
result->set(scratch, 0);
return Status(error::OUT_OF_RANGE, "Read after file end");
}
const uint64 region_left =
std::min(length_ - offset, static_cast<uint64>(to_read));
result->set(reinterpret_cast<const uint8*>(data_) + offset, region_left);
return (region_left == to_read)
? Status::OK()
: Status(error::OUT_OF_RANGE, "Read less bytes than requested");
}
private:
const void* const data_;
const uint64 length_;
// intentionally copyable
};
} // namespace
MemmappedFileSystem::MemmappedFileSystem() {}
bool MemmappedFileSystem::FileExists(const string& fname) {
if (!mapped_memory_) {
return false;
}
const auto dir_element = directory_.find(fname);
return dir_element != directory_.end();
}
Status MemmappedFileSystem::NewRandomAccessFile(const string& filename,
RandomAccessFile** result) {
if (!mapped_memory_) {
return errors::FailedPrecondition("MemmappedEnv is not initialized");
}
const auto dir_element = directory_.find(filename);
if (dir_element == directory_.end()) {
return errors::NotFound("Region ", filename, " is not found");
}
*result = new RandomAccessFileFromMemmapped(
GetMemoryWithOffset(dir_element->second.offset),
dir_element->second.length);
return Status::OK();
}
Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile(
const string& filename, ReadOnlyMemoryRegion** result) {
if (!mapped_memory_) {
return errors::FailedPrecondition("MemmappedEnv is not initialized");
}
const auto dir_element = directory_.find(filename);
if (dir_element == directory_.end()) {
return errors::NotFound("Region ", filename, " is not found");
}
*result = new ReadOnlyMemoryRegionFromMemmapped(
GetMemoryWithOffset(dir_element->second.offset),
dir_element->second.length);
return Status::OK();
}
Status MemmappedFileSystem::GetFileSize(const string& filename, uint64* size) {
if (!mapped_memory_) {
return errors::FailedPrecondition("MemmappedEnv is not initialized");
}
const auto dir_element = directory_.find(filename);
if (dir_element == directory_.end()) {
return errors::NotFound("Region ", filename, " is not found");
}
*size = dir_element->second.length;
return Status::OK();
}
Status MemmappedFileSystem::NewWritableFile(const string& filename,
WritableFile** wf) {
return errors::Unimplemented("memmapped format doesn't support writing");
}
Status MemmappedFileSystem::NewAppendableFile(const string& filename,
WritableFile** result) {
return errors::Unimplemented("memmapped format doesn't support writing");
}
Status MemmappedFileSystem::GetChildren(const string& filename,
std::vector<string>* strings) {
return errors::Unimplemented("memmapped format doesn't support GetChildren");
}
Status MemmappedFileSystem::DeleteFile(const string& filename) {
return errors::Unimplemented("memmapped format doesn't support DeleteFile");
}
Status MemmappedFileSystem::CreateDir(const string& dirname) {
return errors::Unimplemented("memmapped format doesn't support CreateDir");
}
Status MemmappedFileSystem::DeleteDir(const string& dirname) {
return errors::Unimplemented("memmapped format doesn't support DeleteDir");
}
Status MemmappedFileSystem::RenameFile(const string& filename_from,
const string& filename_to) {
return errors::Unimplemented("memmapped format doesn't support RenameFile");
}
const void* MemmappedFileSystem::GetMemoryWithOffset(uint64 offset) const {
return reinterpret_cast<const uint8*>(mapped_memory_->data()) + offset;
}
constexpr char MemmappedFileSystem::kMemmappedPackagePrefix[];
constexpr char MemmappedFileSystem::kMemmappedPackageDefaultGraphDef[];
Status MemmappedFileSystem::InitializeFromFile(Env* env,
const string& filename) {
ReadOnlyMemoryRegion* region;
TF_RETURN_IF_ERROR(env->NewReadOnlyMemoryRegionFromFile(filename, &region));
mapped_memory_.reset(region);
directory_.clear();
if (mapped_memory_->length() <= sizeof(uint64)) {
return errors::DataLoss("Corrupted memmapped model file: ", filename,
" Invalid package size");
}
const auto memory_start =
reinterpret_cast<const uint8*>(mapped_memory_->data());
const uint64 directory_offset = DecodeUint64LittleEndian(
memory_start + mapped_memory_->length() - sizeof(uint64));
if (directory_offset > mapped_memory_->length() - sizeof(uint64)) {
return errors::DataLoss("Corrupted memmapped model file: ", filename,
" Invalid directory offset");
}
MemmappedFileSystemDirectory proto_directory;
if (!ParseProtoUnlimited(
&proto_directory, memory_start + directory_offset,
mapped_memory_->length() - directory_offset - sizeof(uint64))) {
return errors::DataLoss("Corrupted memmapped model file: ", filename,
" Can't parse its internal directory");
}
// Iterating in reverse order to get lengths of elements;
uint64 prev_element_offset = directory_offset;
for (auto element_iter = proto_directory.element().rbegin();
element_iter != proto_directory.element().rend(); ++element_iter) {
// Check that the element offset is in the right range.
if (element_iter->offset() >= prev_element_offset) {
return errors::DataLoss("Corrupted memmapped model file: ", filename,
" Invalid offset of internal component");
}
if (!directory_
.insert(std::make_pair(
element_iter->name(),
FileRegion(element_iter->offset(),
prev_element_offset - element_iter->offset())))
.second) {
return errors::DataLoss("Corrupted memmapped model file: ", filename,
" Duplicate name of internal component ",
element_iter->name());
}
prev_element_offset = element_iter->offset();
}
return Status::OK();
}
bool MemmappedFileSystem::IsMemmappedPackageFilename(const string& filename) {
return StringPiece(filename).starts_with(kMemmappedPackagePrefix);
}
namespace {
bool IsValidRegionChar(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || c == '_' || c == '.';
}
} // namespace
bool MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
const string& filename) {
if (!IsMemmappedPackageFilename(filename)) {
return false;
}
for (char c :
filename.substr(strlen(kMemmappedPackagePrefix),
filename.length() - strlen(kMemmappedPackagePrefix))) {
if (!IsValidRegionChar(c)) {
return false;
}
}
return true;
}
MemmappedEnv::MemmappedEnv(Env* env) : EnvWrapper(env) {}
Status MemmappedEnv::GetFileSystemForFile(const string& fname,
FileSystem** result) {
if (MemmappedFileSystem::IsMemmappedPackageFilename(fname)) {
if (!memmapped_file_system_) {
return errors::FailedPrecondition(
"MemmappedEnv is not initialized from a file.");
}
*result = memmapped_file_system_.get();
return Status::OK();
}
return EnvWrapper::GetFileSystemForFile(fname, result);
}
Status MemmappedEnv::GetRegisteredFileSystemSchemes(
std::vector<string>* schemes) {
const auto status = EnvWrapper::GetRegisteredFileSystemSchemes(schemes);
if (status.ok()) {
schemes->emplace_back(MemmappedFileSystem::kMemmappedPackagePrefix);
}
return status;
}
Status MemmappedEnv::InitializeFromFile(const string& package_filename) {
std::unique_ptr<MemmappedFileSystem> file_system_ptr(new MemmappedFileSystem);
const auto status =
file_system_ptr->InitializeFromFile(target(), package_filename);
if (status.ok()) {
memmapped_file_system_ = std::move(file_system_ptr);
}
return status;
}
} // namespace tensorflow

View File

@ -0,0 +1,121 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_
#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
// A file system that uses a graph saved in memmapped format by
// MemmappedEnvWriter as a file system.
//
// The format supports saved tensors and protos. Tensors are saved at aligned
// offsets.
//
// Format specification:
// - last 8 bytes of a package is encoded offset to the directory. The encoding
// is always little endian, independently from the platform, done by functions
// EncodeUint64LittleEndian/DecodeUint64LittleEndian
// - the directory starts from the encoded offset and is saved proto
// MemmappedFileSystemDirectory with names and offsets to the regions.
// - at the offsets in the directory the file regions are stored. Tensor regions
// are aligned such way that when the package mapped to RAM they have the right
// offset to be used by ImmutableConst operator.
//
// Region naming:
// Region naming is up to the application, all of them starts from
// kMemmappedPackagePrefix. The default graph usually has name
// kMemmappedPackageDefaultGraphDef; for more details see the conversion
// utility
// third_party/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
//
// A "frozen" GraphDef can be converted into this format using
// tensorflow/contrib/util/convert_graphdef_memmapped_format
class MemmappedFileSystem : public FileSystem {
public:
// Memmapped regions use this prefix to distinguish from
// the filesystem.
static constexpr char kMemmappedPackagePrefix[] = "memmapped_package://";
// The default graphdef in the package.
static constexpr char kMemmappedPackageDefaultGraphDef[] =
"memmapped_package://.";
MemmappedFileSystem();
~MemmappedFileSystem() override = default;
bool FileExists(const string& fname) override;
Status NewRandomAccessFile(const string& filename,
RandomAccessFile** result) override;
Status NewReadOnlyMemoryRegionFromFile(
const string& filename, ReadOnlyMemoryRegion** result) override;
// All these functions return Unimplemented error, the memmapped storage is
// read only.
Status NewWritableFile(const string& fname, WritableFile** result) override;
Status NewAppendableFile(const string& fname, WritableFile** result) override;
Status GetChildren(const string& dir, std::vector<string>* r) override;
Status DeleteFile(const string& f) override;
Status CreateDir(const string& d) override;
Status DeleteDir(const string& d) override;
Status GetFileSize(const string& f, uint64* s) override;
Status RenameFile(const string& s, const string& t) override;
// Initializes filesystem from a file in memmapped format.
Status InitializeFromFile(Env* env, const string& filename);
// Checks if the filename has a correct prefix.
static bool IsMemmappedPackageFilename(const string& filename);
static bool IsWellFormedMemmappedPackageFilename(const string& filename);
private:
struct FileRegion {
FileRegion(uint64 o, uint64 l) : offset(o), length(l) {}
uint64 offset; // Offset from the beginning of the file.
uint64 length; // Length of the region.
};
using DirectoryType = std::unordered_map<string, FileRegion>;
const void* GetMemoryWithOffset(uint64 offset) const;
std::unique_ptr<ReadOnlyMemoryRegion> mapped_memory_;
DirectoryType directory_;
TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystem);
};
class MemmappedEnv : public EnvWrapper {
public:
explicit MemmappedEnv(Env* env);
~MemmappedEnv() override = default;
Status GetFileSystemForFile(const string& fname,
FileSystem** result) override;
Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes) override;
Status InitializeFromFile(const string& filename);
protected:
std::unique_ptr<MemmappedFileSystem> memmapped_file_system_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_

View File

@ -0,0 +1,29 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package tensorflow;
// option cc_enable_arenas = true;
// A message that describes one region of memmapped file.
message MemmappedFileSystemDirectoryElement {
uint64 offset = 1;
string name = 2;
}
// A directory of regions in a memmapped file.
message MemmappedFileSystemDirectory {
repeated MemmappedFileSystemDirectoryElement element = 1;
}

View File

@ -0,0 +1,150 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/memmapped_file_system.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/memmapped_file_system_writer.h"
namespace tensorflow {
namespace {
// Names of files in memmapped environment.
constexpr char kTensor1FileName[] = "memmapped_package://t1";
constexpr char kTensor2FileName[] = "memmapped_package://t2";
constexpr char kProtoFileName[] = "memmapped_package://b";
constexpr int kTestGraphDefVersion = 666;
Status CreateMemmappedFileSystemFile(const string& filename, bool corrupted,
Tensor* test_tensor) {
Env* env = Env::Default();
MemmappedFileSystemWriter writer;
TF_RETURN_IF_ERROR(writer.InitializeToFile(env, filename));
// Try to write a tensor and proto.
test::FillFn<float>(test_tensor,
[](int i) { return static_cast<float>(i * i); });
TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor1FileName));
// Create a proto with some fields.
GraphDef graph_def;
graph_def.set_version(kTestGraphDefVersion);
TF_RETURN_IF_ERROR(writer.SaveProtobuf(graph_def, kProtoFileName));
// Save a tensor after the proto to check that alignment works.
test::FillFn<float>(test_tensor,
[](int i) { return static_cast<float>(i * i * i); });
TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor2FileName));
if (!corrupted) {
// Flush and close the file.
TF_RETURN_IF_ERROR(writer.FlushAndClose());
}
return Status::OK();
}
TEST(MemmappedFileSystemTest, SimpleTest) {
const TensorShape test_tensor_shape = {10, 200};
Tensor test_tensor(DT_FLOAT, test_tensor_shape);
const string dir = testing::TmpDir();
const string filename = io::JoinPath(dir, "memmapped_env_test");
TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, false, &test_tensor));
// Check that we can memmap the created file.
MemmappedEnv memmapped_env(Env::Default());
TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename));
// Try to load a proto from the file.
GraphDef test_graph_def;
TF_EXPECT_OK(
ReadBinaryProto(&memmapped_env, kProtoFileName, &test_graph_def));
EXPECT_EQ(kTestGraphDefVersion, test_graph_def.version());
// Check that we can correctly get a tensor memory.
ReadOnlyMemoryRegion* memory_region;
TF_ASSERT_OK(memmapped_env.NewReadOnlyMemoryRegionFromFile(kTensor2FileName,
&memory_region));
std::unique_ptr<ReadOnlyMemoryRegion> mem_region_ptr(memory_region);
// The memory region can be bigger but not less than Tensor size.
ASSERT_GE(memory_region->length(), test_tensor.TotalBytes());
EXPECT_EQ(test_tensor.tensor_data(),
StringPiece(static_cast<const char*>(memory_region->data()),
test_tensor.TotalBytes()));
// Check that GetFileSize works.
uint64 file_size = 0;
TF_ASSERT_OK(memmapped_env.GetFileSize(kTensor2FileName, &file_size));
EXPECT_EQ(test_tensor.TotalBytes(), file_size);
// Check that if file not found correct error message returned.
EXPECT_EQ(
error::NOT_FOUND,
memmapped_env.NewReadOnlyMemoryRegionFromFile("bla-bla", &memory_region)
.code());
// Check FileExists.
EXPECT_TRUE(memmapped_env.FileExists(kTensor2FileName));
EXPECT_FALSE(memmapped_env.FileExists("bla-bla-bla"));
}
TEST(MemmappedFileSystemTest, NotInitalized) {
MemmappedEnv memmapped_env(Env::Default());
ReadOnlyMemoryRegion* memory_region;
EXPECT_EQ(
error::FAILED_PRECONDITION,
memmapped_env
.NewReadOnlyMemoryRegionFromFile(kTensor1FileName, &memory_region)
.code());
RandomAccessFile* file;
EXPECT_EQ(error::FAILED_PRECONDITION,
memmapped_env.NewRandomAccessFile(kProtoFileName, &file).code());
}
TEST(MemmappedFileSystemTest, Corrupted) {
// Create a corrupted file (it is not closed it properly).
const TensorShape test_tensor_shape = {100, 200};
Tensor test_tensor(DT_FLOAT, test_tensor_shape);
const string dir = testing::TmpDir();
const string filename = io::JoinPath(dir, "memmapped_env_corrupted_test");
TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, true, &test_tensor));
MemmappedFileSystem memmapped_env;
ASSERT_NE(memmapped_env.InitializeFromFile(Env::Default(), filename),
Status::OK());
}
TEST(MemmappedFileSystemTest, ProxyToDefault) {
MemmappedEnv memmapped_env(Env::Default());
const string dir = testing::TmpDir();
const string filename = io::JoinPath(dir, "test_file");
// Check that we can create write and read ordinary file.
WritableFile* writable_file;
TF_ASSERT_OK(memmapped_env.NewAppendableFile(filename, &writable_file));
std::unique_ptr<WritableFile> writable_file_ptr(writable_file);
const string test_string = "bla-bla-bla";
TF_ASSERT_OK(writable_file->Append(test_string));
TF_ASSERT_OK(writable_file->Close());
uint64 file_length = 0;
TF_EXPECT_OK(memmapped_env.GetFileSize(filename, &file_length));
EXPECT_EQ(test_string.length(), file_length);
RandomAccessFile* random_access_file;
TF_ASSERT_OK(
memmapped_env.NewRandomAccessFile(filename, &random_access_file));
delete random_access_file;
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,136 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/memmapped_file_system_writer.h"
#include <algorithm>
namespace tensorflow {
Status MemmappedFileSystemWriter::InitializeToFile(Env* env,
const string& filename) {
WritableFile* writable_file;
auto status = env->NewWritableFile(filename, &writable_file);
if (status.ok()) {
output_file_.reset(writable_file);
output_file_offset_ = 0;
}
return status;
}
Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor,
const string& element_name) {
if (!output_file_) {
return errors::FailedPrecondition(
"MemmappedEnvWritter: saving tensor into not opened file");
}
if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
element_name)) {
return errors::InvalidArgument(
"MemmappedEnvWritter: element_name is invalid: must have memmapped ",
"package prefix ", MemmappedFileSystem::kMemmappedPackagePrefix,
" and include [A-Za-z0-9_.]");
}
const auto tensor_data = tensor.tensor_data();
if (0 == tensor_data.size()) {
return errors::InvalidArgument(
"MemmappedEnvWritter: saving tensor with 0 size");
}
// Adds pad for correct alignment after memmapping.
TF_RETURN_IF_ERROR(AdjustAlignment(Allocator::kAllocatorAlignment));
AddToDirectoryElement(element_name);
const auto result = output_file_->Append(tensor_data);
if (result.ok()) {
output_file_offset_ += tensor_data.size();
}
return result;
}
Status MemmappedFileSystemWriter::SaveProtobuf(
const protobuf::MessageLite& message, const string& element_name) {
if (!output_file_) {
return errors::FailedPrecondition(
"MemmappedEnvWritter: saving protobuf into not opened file");
}
if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
element_name)) {
return errors::InvalidArgument(
"MemmappedEnvWritter: element_name is invalid: must have memmapped "
"package prefix ",
MemmappedFileSystem::kMemmappedPackagePrefix,
" and include [A-Za-z0-9_.]");
}
AddToDirectoryElement(element_name);
const string encoded = message.SerializeAsString();
const auto res = output_file_->Append(encoded);
if (res.ok()) {
output_file_offset_ += encoded.size();
}
return res;
}
namespace {
StringPiece EncodeUint64LittleEndian(uint64 val, char* output_buffer) {
for (int i = 0; i < sizeof(uint64); ++i) {
output_buffer[i] = (val >> i * 8);
}
return {output_buffer, sizeof(uint64)};
}
} // namespace
Status MemmappedFileSystemWriter::FlushAndClose() {
if (!output_file_) {
return errors::FailedPrecondition(
"MemmappedEnvWritter: flushing into not opened file");
}
const string dir = directory_.SerializeAsString();
TF_RETURN_IF_ERROR(output_file_->Append(dir));
// Write the directory offset.
char buffer[sizeof(uint64)];
TF_RETURN_IF_ERROR(output_file_->Append(
EncodeUint64LittleEndian(output_file_offset_, buffer)));
// Flush and close the file.
TF_RETURN_IF_ERROR(output_file_->Flush());
TF_RETURN_IF_ERROR(output_file_->Close());
output_file_.reset();
return Status::OK();
}
Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) {
const uint64 alignment_rest = output_file_offset_ % alignment;
const uint64 to_write_for_alignment =
(alignment_rest == 0) ? 0 : alignment - (output_file_offset_ % alignment);
static constexpr uint64 kFillerBufferSize = 16;
const char kFillerBuffer[kFillerBufferSize] = {};
for (uint64 rest = to_write_for_alignment; rest > 0;) {
StringPiece sp(kFillerBuffer, std::min(rest, kFillerBufferSize));
TF_RETURN_IF_ERROR(output_file_->Append(sp));
rest -= sp.size();
output_file_offset_ += sp.size();
}
return Status::OK();
}
void MemmappedFileSystemWriter::AddToDirectoryElement(const string& name) {
MemmappedFileSystemDirectoryElement* new_directory_element =
directory_.add_element();
new_directory_element->set_offset(output_file_offset_);
new_directory_element->set_name(name);
}
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_
#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_
#include <memory>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"
#include "tensorflow/core/util/memmapped_file_system.pb.h"
namespace tensorflow {
// A class for saving into the memmapped format that can be read by
// MemmappedFileSystem.
class MemmappedFileSystemWriter {
public:
MemmappedFileSystemWriter() = default;
~MemmappedFileSystemWriter() = default;
Status InitializeToFile(Env* env, const string& filename);
Status SaveTensor(const Tensor& tensor, const string& element_name);
Status SaveProtobuf(const protobuf::MessageLite& message,
const string& element_name);
// Writes out the directory of regions and closes the output file.
Status FlushAndClose();
private:
Status AdjustAlignment(uint64 alignment);
void AddToDirectoryElement(const string& element_name);
MemmappedFileSystemDirectory directory_;
// The current offset in the file, to support alignment.
uint64 output_file_offset_ = 0;
std::unique_ptr<WritableFile> output_file_;
TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystemWriter);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_

View File

@ -61,6 +61,7 @@ def tf_android_core_proto_sources_relative():
"lib/core/error_codes.proto",
"protobuf/config.proto",
"protobuf/saver.proto",
"util/memmapped_file_system.proto",
"util/saved_tensor_slice.proto",
"util/test_log.proto",
]