Added a format for saving an inference graph that can be memmapped and an utility to convert a freezed graph into this format.
Change: 120128412
This commit is contained in:
parent
517d3af445
commit
3c280f6fa0
@ -7,6 +7,47 @@ exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
# Convertor of a frozen graph definition into the memmapped format.
|
||||
cc_library(
|
||||
name = "convert_graphdef_memmapped_format_lib",
|
||||
srcs = ["convert_graphdef_memmapped_format_lib.cc"],
|
||||
hdrs = ["convert_graphdef_memmapped_format_lib.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core/kernels:immutable_constant_op",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "convert_graphdef_memmapped_format",
|
||||
srcs = ["convert_graphdef_memmapped_format.cc"],
|
||||
deps = [
|
||||
":convert_graphdef_memmapped_format_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "convert_graphdef_memmapped_format_test",
|
||||
srcs = ["convert_graphdef_memmapped_format_test.cc"],
|
||||
deps = [
|
||||
":convert_graphdef_memmapped_format_lib",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensor_testutil",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "inspect_checkpoint",
|
||||
srcs = ["inspect_checkpoint.cc"],
|
||||
|
88
tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
Normal file
88
tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Utility that converts a "frozen" inference graph (output from the
|
||||
// freeze_graph utility) into a format in which large Const ops are converted to
|
||||
// ImmutableConst ops which are memmapped when the graph is executed by
|
||||
// TensorFlow.
|
||||
//
|
||||
// tensorflow/contrib/util/convert_graphdef_memmapped_format
|
||||
// --in_graph=frozen.model --out_graph=memmapped.mmodel
|
||||
//
|
||||
// Parameters:
|
||||
// in_graph - name of a file with a frozen GraphDef proto in binary format
|
||||
// out_graph - name of the output file, where the graph in memmapped format will
|
||||
// be saved.
|
||||
// min_conversion_size_bytes - tensors with fewer than this many bytes of data
|
||||
// will not be converted to ImmutableConst format, and kept in the graph.
|
||||
|
||||
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
|
||||
string in_graph = "";
|
||||
string out_graph = "";
|
||||
int min_conversion_tensor_size = 10000;
|
||||
const bool parse_result = ParseFlags(
|
||||
&argc, argv,
|
||||
{// input graph
|
||||
Flag("in_graph", &in_graph),
|
||||
// output graph
|
||||
Flag("out_graph", &out_graph),
|
||||
// constants with tensors that have less than this number elements won't
|
||||
// be converted into ImmutableConst (be memmapped).
|
||||
Flag("min_conversion_tensor_size", &min_conversion_tensor_size)});
|
||||
// We need to call this to set up global state for TensorFlow.
|
||||
port::InitMain(argv[0], &argc, &argv);
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << "Error parsing command-line flags.";
|
||||
return -1;
|
||||
}
|
||||
if (argc > 1) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1];
|
||||
return -1;
|
||||
}
|
||||
if (in_graph.empty()) {
|
||||
LOG(ERROR) << "in_graph graph can't be empty";
|
||||
return -1;
|
||||
}
|
||||
if (out_graph.empty()) {
|
||||
LOG(ERROR) << "out_graph graph can't be empty";
|
||||
return -1;
|
||||
}
|
||||
if (min_conversion_tensor_size <= 0) {
|
||||
LOG(ERROR) << "min_conversion_tensor_size must be > 0";
|
||||
return -1;
|
||||
}
|
||||
const auto result = ConvertConstantsToImmutable(in_graph, out_graph,
|
||||
min_conversion_tensor_size);
|
||||
if (!result.ok()) {
|
||||
LOG(ERROR) << "Conversion failed " << result.error_message();
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
|
||||
}
|
156
tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
Normal file
156
tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
Normal file
@ -0,0 +1,156 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/kernels/immutable_constant_op.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
class NodeConverter {
|
||||
public:
|
||||
// Converts one node. In-place updates node_def, writes the tensor in
|
||||
// memmapped
|
||||
// format, using writer. If the conversion has been done, convert_counter is
|
||||
// increased.
|
||||
Status ConvertConstantsToImmutable(NodeDef* node_def,
|
||||
MemmappedFileSystemWriter* writer,
|
||||
int* convert_counter,
|
||||
int min_conversion_size_bytes) {
|
||||
// Check the size.
|
||||
const AttrValue& value = node_def->attr().at("value");
|
||||
const TensorProto& tensor_proto = value.tensor();
|
||||
|
||||
// Create copies of tensor datatype and shape, to put into the operator
|
||||
// after
|
||||
// the tensor is destroyed.
|
||||
const DataType tensor_data_type = tensor_proto.dtype();
|
||||
const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
|
||||
|
||||
// Create Tensor from value and write it in memmapped format.
|
||||
Tensor parsed(tensor_proto.dtype());
|
||||
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
tensor_proto.DebugString());
|
||||
}
|
||||
if (parsed.TotalBytes() < min_conversion_size_bytes) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const string memmapped_region_name =
|
||||
MemmappedFileSystem::kMemmappedPackagePrefix +
|
||||
ConvertVariableNameToUniqueRegionName(node_def->name());
|
||||
|
||||
TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name));
|
||||
|
||||
node_def->set_op("ImmutableConst");
|
||||
|
||||
// Erase all attributes and leave only attributes that can be understood by
|
||||
// ImmutableConst.
|
||||
auto* mutable_attr = node_def->mutable_attr();
|
||||
mutable_attr->clear();
|
||||
|
||||
{
|
||||
AttrValue attr_value;
|
||||
attr_value.set_type(tensor_data_type);
|
||||
mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value});
|
||||
}
|
||||
{
|
||||
AttrValue attr_value;
|
||||
*(attr_value.mutable_shape()) = tensor_shape;
|
||||
mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value});
|
||||
}
|
||||
{
|
||||
AttrValue attr_value;
|
||||
attr_value.set_s(memmapped_region_name);
|
||||
mutable_attr->insert(
|
||||
{ImmutableConstantOp::kMemoryRegionNameAttr, attr_value});
|
||||
}
|
||||
++*convert_counter;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
string ConvertVariableNameToUniqueRegionName(const string& variable_name) {
|
||||
string region_name = SanitizeVariableName(variable_name);
|
||||
while (!used_names_.insert(region_name).second) {
|
||||
region_name += '_';
|
||||
}
|
||||
return region_name;
|
||||
}
|
||||
|
||||
static string SanitizeVariableName(const string& variable_name) {
|
||||
string result;
|
||||
for (char c : variable_name) {
|
||||
if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '.') {
|
||||
result += c;
|
||||
} else {
|
||||
result += '_';
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
std::unordered_set<string> used_names_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Loads the graph, replaces operators, and writes it out.
|
||||
Status ConvertConstantsToImmutable(const string& in_graph_filename,
|
||||
const string& out_graph_filename,
|
||||
int min_conversion_size_bytes) {
|
||||
Env* default_env = Env::Default();
|
||||
GraphDef graph_def;
|
||||
const auto load_graph_status =
|
||||
ReadBinaryProto(default_env, in_graph_filename, &graph_def);
|
||||
if (!load_graph_status.ok()) {
|
||||
return tensorflow::errors::NotFound("Failed to load graph at '",
|
||||
in_graph_filename, "' : ",
|
||||
load_graph_status.error_message());
|
||||
}
|
||||
|
||||
NodeConverter node_converter;
|
||||
|
||||
// Create output writer.
|
||||
MemmappedFileSystemWriter writer;
|
||||
TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename));
|
||||
|
||||
// Iterate over graph nodes, looking for Const and replacing it with
|
||||
// ImmutableConst.
|
||||
int convert_counter = 0;
|
||||
for (int i = 0; i < graph_def.node_size(); ++i) {
|
||||
const NodeDef& node = graph_def.node(i);
|
||||
if (node.op() == "Const") {
|
||||
// Try to convert to ImmutableConst
|
||||
TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable(
|
||||
graph_def.mutable_node(i), &writer, &convert_counter,
|
||||
min_conversion_size_bytes));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer.SaveProtobuf(
|
||||
graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef));
|
||||
TF_RETURN_IF_ERROR(writer.FlushAndClose());
|
||||
LOG(INFO) << "Converted " << convert_counter << " nodes";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Converts a "frozen" inference graph (output from the freeze_graph utility)
|
||||
// into a format in which large Const ops are converted to ImmutableConst ops
|
||||
// which are memmapped when the graph is executed by TensorFlow.
|
||||
Status ConvertConstantsToImmutable(const string& in_graph_filename,
|
||||
const string& out_graph_filename,
|
||||
int min_conversion_size_bytes);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
|
@ -0,0 +1,84 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
|
||||
|
||||
// Create a simple graph and write it to filename_pb.
|
||||
constexpr int kTensorWidth = 4000;
|
||||
constexpr int kTensorHeight = 100;
|
||||
const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
|
||||
const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth});
|
||||
|
||||
Tensor test_tensor1(DT_FLOAT, kTestTensorShape);
|
||||
test::FillFn<float>(&test_tensor1, [](int) -> float { return 2.0; });
|
||||
|
||||
Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT);
|
||||
test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; });
|
||||
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Node* node1 = ops::Const(test_tensor1, b.opts());
|
||||
Node* node2 = ops::Const(test_tensor2, b.opts());
|
||||
const string result_name = ops::MatMul(node1, node2, b.opts())->name();
|
||||
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(b.ToGraphDef(&graph_def));
|
||||
string graph_def_serialized;
|
||||
graph_def.SerializeToString(&graph_def_serialized);
|
||||
TF_ASSERT_OK(
|
||||
WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));
|
||||
|
||||
const string filename_mmap = io::JoinPath(dir, "graphdef.mmap");
|
||||
TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000));
|
||||
|
||||
// Create and initialize MemmappedEnv from the converted file.
|
||||
MemmappedEnv memmapped_env(Env::Default());
|
||||
TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));
|
||||
|
||||
// Load the graph and run calculations.
|
||||
SessionOptions session_options;
|
||||
session_options.env = &memmapped_env;
|
||||
std::unique_ptr<Session> session(NewSession(session_options));
|
||||
ASSERT_TRUE(session != nullptr) << "Failed to create session";
|
||||
GraphDef loaded_graph_def;
|
||||
TF_ASSERT_OK(ReadBinaryProto(
|
||||
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
|
||||
&loaded_graph_def));
|
||||
|
||||
TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs));
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f * kTensorHeight);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f * kTensorHeight);
|
||||
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -261,6 +261,8 @@ tf_cuda_library(
|
||||
"util/device_name_utils.h",
|
||||
"util/events_writer.h",
|
||||
"util/guarded_philox_random.h",
|
||||
"util/memmapped_file_system.h",
|
||||
"util/memmapped_file_system_writer.h",
|
||||
"util/mirror_pad_mode.h",
|
||||
"util/padding.h",
|
||||
"util/port.h",
|
||||
|
@ -43,6 +43,8 @@ string AllocatorStats::DebugString() const {
|
||||
this->num_allocs, this->max_alloc_size);
|
||||
}
|
||||
|
||||
constexpr size_t Allocator::kAllocatorAlignment;
|
||||
|
||||
Allocator::~Allocator() {}
|
||||
|
||||
// If true, cpu allocator collects more stats.
|
||||
|
@ -66,6 +66,9 @@ struct AllocatorStats {
|
||||
// device memory.
|
||||
class Allocator {
|
||||
public:
|
||||
// Align to 32 byte boundary.
|
||||
static constexpr size_t kAllocatorAlignment = 32;
|
||||
|
||||
virtual ~Allocator();
|
||||
|
||||
// Return a string identifying this allocator
|
||||
@ -112,8 +115,8 @@ class Allocator {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void* p = AllocateRaw(32 /* align to 32 byte boundary */,
|
||||
sizeof(T) * num_elements, allocation_attr);
|
||||
void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements,
|
||||
allocation_attr);
|
||||
T* typed_p = reinterpret_cast<T*>(p);
|
||||
if (typed_p) RunCtor<T>(typed_p, num_elements);
|
||||
return typed_p;
|
||||
@ -192,11 +195,10 @@ class Allocator {
|
||||
// without running their default ctors and dtors.
|
||||
template <typename T>
|
||||
struct is_simple {
|
||||
static const bool value = std::is_trivial<T>::value ||
|
||||
std::is_same<T, Eigen::half>::value ||
|
||||
std::is_same<T, complex64>::value ||
|
||||
std::is_same<T, complex128>::value ||
|
||||
is_quantized<T>::value;
|
||||
static constexpr bool value =
|
||||
std::is_trivial<T>::value || std::is_same<T, Eigen::half>::value ||
|
||||
std::is_same<T, complex64>::value ||
|
||||
std::is_same<T, complex128>::value || is_quantized<T>::value;
|
||||
};
|
||||
|
||||
// Fills in 'stats' with statistics collected by this allocator.
|
||||
|
@ -20,11 +20,12 @@ namespace tensorflow {
|
||||
ImmutableConstantOp::ImmutableConstantOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
::tensorflow::DataType dtype;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype));
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kDTypeAttr, &dtype));
|
||||
::tensorflow::TensorShape shape;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape));
|
||||
OP_REQUIRES_OK(context, context->GetAttr(kShapeAttr, &shape));
|
||||
string region_name;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("memory_region_name", ®ion_name));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr(kMemoryRegionNameAttr, ®ion_name));
|
||||
OP_REQUIRES_OK(context,
|
||||
allocator_.InitWithMemoryRegion(region_name, context->env()));
|
||||
tensor_ = ::tensorflow::Tensor(&allocator_, dtype, shape);
|
||||
@ -90,6 +91,10 @@ void ImmutableConstantOp::ReadOnlyMemoryRegionAllocator::DeallocateRaw(
|
||||
}
|
||||
}
|
||||
|
||||
constexpr char ImmutableConstantOp::kDTypeAttr[];
|
||||
constexpr char ImmutableConstantOp::kShapeAttr[];
|
||||
constexpr char ImmutableConstantOp::kMemoryRegionNameAttr[];
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ImmutableConst").Device(DEVICE_CPU),
|
||||
ImmutableConstantOp);
|
||||
} // namespace tensorflow
|
||||
|
@ -32,6 +32,11 @@ class ImmutableConstantOp : public OpKernel {
|
||||
bool IsExpensive() override { return false; }
|
||||
~ImmutableConstantOp() override;
|
||||
|
||||
// Names of attributes that are used by this op
|
||||
static constexpr char kDTypeAttr[] = "dtype";
|
||||
static constexpr char kShapeAttr[] = "shape";
|
||||
static constexpr char kMemoryRegionNameAttr[] = "memory_region_name";
|
||||
|
||||
private:
|
||||
class ReadOnlyMemoryRegionAllocator : public ::tensorflow::Allocator {
|
||||
public:
|
||||
|
@ -23,7 +23,7 @@ namespace tensorflow {
|
||||
|
||||
class FileSystemRegistryImpl : public FileSystemRegistry {
|
||||
public:
|
||||
void Register(const string& scheme, Factory factory) override;
|
||||
Status Register(const string& scheme, Factory factory) override;
|
||||
FileSystem* Lookup(const string& scheme) override;
|
||||
Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes) override;
|
||||
|
||||
@ -33,13 +33,15 @@ class FileSystemRegistryImpl : public FileSystemRegistry {
|
||||
GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
void FileSystemRegistryImpl::Register(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) {
|
||||
Status FileSystemRegistryImpl::Register(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) {
|
||||
mutex_lock lock(mu_);
|
||||
QCHECK(
|
||||
registry_.emplace(string(scheme), std::unique_ptr<FileSystem>(factory()))
|
||||
.second)
|
||||
<< "File factory for " << scheme << " already registered";
|
||||
if (!registry_.emplace(string(scheme), std::unique_ptr<FileSystem>(factory()))
|
||||
.second) {
|
||||
return errors::AlreadyExists("File factory for ", scheme,
|
||||
" already registered");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FileSystem* FileSystemRegistryImpl::Lookup(const string& scheme) {
|
||||
@ -77,9 +79,9 @@ Status Env::GetRegisteredFileSystemSchemes(std::vector<string>* schemes) {
|
||||
return file_system_registry_->GetRegisteredFileSystemSchemes(schemes);
|
||||
}
|
||||
|
||||
void Env::RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) {
|
||||
file_system_registry_->Register(scheme, factory);
|
||||
Status Env::RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) {
|
||||
return file_system_registry_->Register(scheme, factory);
|
||||
}
|
||||
|
||||
Status Env::NewRandomAccessFile(const string& fname,
|
||||
|
@ -68,8 +68,8 @@ class Env {
|
||||
virtual Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes);
|
||||
|
||||
// \brief Register a file system for a scheme.
|
||||
virtual void RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory);
|
||||
virtual Status RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory);
|
||||
|
||||
/// \brief Creates a brand new random access read-only file with the
|
||||
/// specified name.
|
||||
@ -236,9 +236,9 @@ class EnvWrapper : public Env {
|
||||
return target_->GetRegisteredFileSystemSchemes(schemes);
|
||||
}
|
||||
|
||||
void RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) override {
|
||||
target_->RegisterFileSystem(scheme, factory);
|
||||
Status RegisterFileSystem(const string& scheme,
|
||||
FileSystemRegistry::Factory factory) override {
|
||||
return target_->RegisterFileSystem(scheme, factory);
|
||||
}
|
||||
|
||||
uint64 NowMicros() override { return target_->NowMicros(); }
|
||||
|
@ -34,7 +34,7 @@ class RandomAccessFile;
|
||||
class ReadOnlyMemoryRegion;
|
||||
class WritableFile;
|
||||
|
||||
/// An generic interface for accessing a file system.
|
||||
/// A generic interface for accessing a file system.
|
||||
class FileSystem {
|
||||
public:
|
||||
FileSystem() {}
|
||||
@ -202,7 +202,7 @@ class FileSystemRegistry {
|
||||
typedef std::function<FileSystem*()> Factory;
|
||||
|
||||
virtual ~FileSystemRegistry();
|
||||
virtual void Register(const string& scheme, Factory factory) = 0;
|
||||
virtual Status Register(const string& scheme, Factory factory) = 0;
|
||||
virtual FileSystem* Lookup(const string& scheme) = 0;
|
||||
virtual Status GetRegisteredFileSystemSchemes(
|
||||
std::vector<string>* schemes) = 0;
|
||||
|
281
tensorflow/core/util/memmapped_file_system.cc
Normal file
281
tensorflow/core/util/memmapped_file_system.cc
Normal file
@ -0,0 +1,281 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
uint64 DecodeUint64LittleEndian(const uint8* buffer) {
|
||||
uint64 result = 0;
|
||||
for (int i = 0; i < static_cast<int>(sizeof(uint64)); ++i) {
|
||||
result |= buffer[i] << (8 * i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
class ReadOnlyMemoryRegionFromMemmapped : public ReadOnlyMemoryRegion {
|
||||
public:
|
||||
ReadOnlyMemoryRegionFromMemmapped(const void* data, uint64 length)
|
||||
: data_(data), length_(length) {}
|
||||
~ReadOnlyMemoryRegionFromMemmapped() override = default;
|
||||
const void* data() override { return data_; }
|
||||
uint64 length() override { return length_; }
|
||||
|
||||
private:
|
||||
const void* const data_;
|
||||
const uint64 length_;
|
||||
// intentionally copyable
|
||||
};
|
||||
|
||||
class RandomAccessFileFromMemmapped : public RandomAccessFile {
|
||||
public:
|
||||
RandomAccessFileFromMemmapped(const void* data, uint64 length)
|
||||
: data_(data), length_(length) {}
|
||||
|
||||
~RandomAccessFileFromMemmapped() override = default;
|
||||
|
||||
Status Read(uint64 offset, size_t to_read, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
if (offset >= length_) {
|
||||
result->set(scratch, 0);
|
||||
return Status(error::OUT_OF_RANGE, "Read after file end");
|
||||
}
|
||||
const uint64 region_left =
|
||||
std::min(length_ - offset, static_cast<uint64>(to_read));
|
||||
result->set(reinterpret_cast<const uint8*>(data_) + offset, region_left);
|
||||
return (region_left == to_read)
|
||||
? Status::OK()
|
||||
: Status(error::OUT_OF_RANGE, "Read less bytes than requested");
|
||||
}
|
||||
|
||||
private:
|
||||
const void* const data_;
|
||||
const uint64 length_;
|
||||
// intentionally copyable
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
MemmappedFileSystem::MemmappedFileSystem() {}
|
||||
|
||||
bool MemmappedFileSystem::FileExists(const string& fname) {
|
||||
if (!mapped_memory_) {
|
||||
return false;
|
||||
}
|
||||
const auto dir_element = directory_.find(fname);
|
||||
return dir_element != directory_.end();
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::NewRandomAccessFile(const string& filename,
|
||||
RandomAccessFile** result) {
|
||||
if (!mapped_memory_) {
|
||||
return errors::FailedPrecondition("MemmappedEnv is not initialized");
|
||||
}
|
||||
const auto dir_element = directory_.find(filename);
|
||||
if (dir_element == directory_.end()) {
|
||||
return errors::NotFound("Region ", filename, " is not found");
|
||||
}
|
||||
*result = new RandomAccessFileFromMemmapped(
|
||||
GetMemoryWithOffset(dir_element->second.offset),
|
||||
dir_element->second.length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile(
|
||||
const string& filename, ReadOnlyMemoryRegion** result) {
|
||||
if (!mapped_memory_) {
|
||||
return errors::FailedPrecondition("MemmappedEnv is not initialized");
|
||||
}
|
||||
const auto dir_element = directory_.find(filename);
|
||||
if (dir_element == directory_.end()) {
|
||||
return errors::NotFound("Region ", filename, " is not found");
|
||||
}
|
||||
*result = new ReadOnlyMemoryRegionFromMemmapped(
|
||||
GetMemoryWithOffset(dir_element->second.offset),
|
||||
dir_element->second.length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::GetFileSize(const string& filename, uint64* size) {
|
||||
if (!mapped_memory_) {
|
||||
return errors::FailedPrecondition("MemmappedEnv is not initialized");
|
||||
}
|
||||
const auto dir_element = directory_.find(filename);
|
||||
if (dir_element == directory_.end()) {
|
||||
return errors::NotFound("Region ", filename, " is not found");
|
||||
}
|
||||
*size = dir_element->second.length;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::NewWritableFile(const string& filename,
|
||||
WritableFile** wf) {
|
||||
return errors::Unimplemented("memmapped format doesn't support writing");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::NewAppendableFile(const string& filename,
|
||||
WritableFile** result) {
|
||||
return errors::Unimplemented("memmapped format doesn't support writing");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::GetChildren(const string& filename,
|
||||
std::vector<string>* strings) {
|
||||
return errors::Unimplemented("memmapped format doesn't support GetChildren");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::DeleteFile(const string& filename) {
|
||||
return errors::Unimplemented("memmapped format doesn't support DeleteFile");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::CreateDir(const string& dirname) {
|
||||
return errors::Unimplemented("memmapped format doesn't support CreateDir");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::DeleteDir(const string& dirname) {
|
||||
return errors::Unimplemented("memmapped format doesn't support DeleteDir");
|
||||
}
|
||||
|
||||
Status MemmappedFileSystem::RenameFile(const string& filename_from,
|
||||
const string& filename_to) {
|
||||
return errors::Unimplemented("memmapped format doesn't support RenameFile");
|
||||
}
|
||||
|
||||
const void* MemmappedFileSystem::GetMemoryWithOffset(uint64 offset) const {
|
||||
return reinterpret_cast<const uint8*>(mapped_memory_->data()) + offset;
|
||||
}
|
||||
|
||||
constexpr char MemmappedFileSystem::kMemmappedPackagePrefix[];
|
||||
constexpr char MemmappedFileSystem::kMemmappedPackageDefaultGraphDef[];
|
||||
|
||||
Status MemmappedFileSystem::InitializeFromFile(Env* env,
|
||||
const string& filename) {
|
||||
ReadOnlyMemoryRegion* region;
|
||||
TF_RETURN_IF_ERROR(env->NewReadOnlyMemoryRegionFromFile(filename, ®ion));
|
||||
mapped_memory_.reset(region);
|
||||
directory_.clear();
|
||||
if (mapped_memory_->length() <= sizeof(uint64)) {
|
||||
return errors::DataLoss("Corrupted memmapped model file: ", filename,
|
||||
" Invalid package size");
|
||||
}
|
||||
const auto memory_start =
|
||||
reinterpret_cast<const uint8*>(mapped_memory_->data());
|
||||
const uint64 directory_offset = DecodeUint64LittleEndian(
|
||||
memory_start + mapped_memory_->length() - sizeof(uint64));
|
||||
if (directory_offset > mapped_memory_->length() - sizeof(uint64)) {
|
||||
return errors::DataLoss("Corrupted memmapped model file: ", filename,
|
||||
" Invalid directory offset");
|
||||
}
|
||||
MemmappedFileSystemDirectory proto_directory;
|
||||
if (!ParseProtoUnlimited(
|
||||
&proto_directory, memory_start + directory_offset,
|
||||
mapped_memory_->length() - directory_offset - sizeof(uint64))) {
|
||||
return errors::DataLoss("Corrupted memmapped model file: ", filename,
|
||||
" Can't parse its internal directory");
|
||||
}
|
||||
|
||||
// Iterating in reverse order to get lengths of elements;
|
||||
uint64 prev_element_offset = directory_offset;
|
||||
for (auto element_iter = proto_directory.element().rbegin();
|
||||
element_iter != proto_directory.element().rend(); ++element_iter) {
|
||||
// Check that the element offset is in the right range.
|
||||
if (element_iter->offset() >= prev_element_offset) {
|
||||
return errors::DataLoss("Corrupted memmapped model file: ", filename,
|
||||
" Invalid offset of internal component");
|
||||
}
|
||||
if (!directory_
|
||||
.insert(std::make_pair(
|
||||
element_iter->name(),
|
||||
FileRegion(element_iter->offset(),
|
||||
prev_element_offset - element_iter->offset())))
|
||||
.second) {
|
||||
return errors::DataLoss("Corrupted memmapped model file: ", filename,
|
||||
" Duplicate name of internal component ",
|
||||
element_iter->name());
|
||||
}
|
||||
prev_element_offset = element_iter->offset();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool MemmappedFileSystem::IsMemmappedPackageFilename(const string& filename) {
|
||||
return StringPiece(filename).starts_with(kMemmappedPackagePrefix);
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsValidRegionChar(char c) {
|
||||
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') || c == '_' || c == '.';
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
|
||||
const string& filename) {
|
||||
if (!IsMemmappedPackageFilename(filename)) {
|
||||
return false;
|
||||
}
|
||||
for (char c :
|
||||
filename.substr(strlen(kMemmappedPackagePrefix),
|
||||
filename.length() - strlen(kMemmappedPackagePrefix))) {
|
||||
if (!IsValidRegionChar(c)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MemmappedEnv::MemmappedEnv(Env* env) : EnvWrapper(env) {}
|
||||
|
||||
Status MemmappedEnv::GetFileSystemForFile(const string& fname,
|
||||
FileSystem** result) {
|
||||
if (MemmappedFileSystem::IsMemmappedPackageFilename(fname)) {
|
||||
if (!memmapped_file_system_) {
|
||||
return errors::FailedPrecondition(
|
||||
"MemmappedEnv is not initialized from a file.");
|
||||
}
|
||||
*result = memmapped_file_system_.get();
|
||||
return Status::OK();
|
||||
}
|
||||
return EnvWrapper::GetFileSystemForFile(fname, result);
|
||||
}
|
||||
|
||||
Status MemmappedEnv::GetRegisteredFileSystemSchemes(
|
||||
std::vector<string>* schemes) {
|
||||
const auto status = EnvWrapper::GetRegisteredFileSystemSchemes(schemes);
|
||||
if (status.ok()) {
|
||||
schemes->emplace_back(MemmappedFileSystem::kMemmappedPackagePrefix);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status MemmappedEnv::InitializeFromFile(const string& package_filename) {
|
||||
std::unique_ptr<MemmappedFileSystem> file_system_ptr(new MemmappedFileSystem);
|
||||
const auto status =
|
||||
file_system_ptr->InitializeFromFile(target(), package_filename);
|
||||
if (status.ok()) {
|
||||
memmapped_file_system_ = std::move(file_system_ptr);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
121
tensorflow/core/util/memmapped_file_system.h
Normal file
121
tensorflow/core/util/memmapped_file_system.h
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_
|
||||
#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A file system that uses a graph saved in memmapped format by
|
||||
// MemmappedEnvWriter as a file system.
|
||||
//
|
||||
// The format supports saved tensors and protos. Tensors are saved at aligned
|
||||
// offsets.
|
||||
//
|
||||
// Format specification:
|
||||
// - last 8 bytes of a package is encoded offset to the directory. The encoding
|
||||
// is always little endian, independently from the platform, done by functions
|
||||
// EncodeUint64LittleEndian/DecodeUint64LittleEndian
|
||||
// - the directory starts from the encoded offset and is saved proto
|
||||
// MemmappedFileSystemDirectory with names and offsets to the regions.
|
||||
// - at the offsets in the directory the file regions are stored. Tensor regions
|
||||
// are aligned such way that when the package mapped to RAM they have the right
|
||||
// offset to be used by ImmutableConst operator.
|
||||
//
|
||||
// Region naming:
|
||||
// Region naming is up to the application, all of them starts from
|
||||
// kMemmappedPackagePrefix. The default graph usually has name
|
||||
// kMemmappedPackageDefaultGraphDef; for more details see the conversion
|
||||
// utility
|
||||
// third_party/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
|
||||
//
|
||||
// A "frozen" GraphDef can be converted into this format using
|
||||
// tensorflow/contrib/util/convert_graphdef_memmapped_format
|
||||
class MemmappedFileSystem : public FileSystem {
|
||||
public:
|
||||
// Memmapped regions use this prefix to distinguish from
|
||||
// the filesystem.
|
||||
static constexpr char kMemmappedPackagePrefix[] = "memmapped_package://";
|
||||
// The default graphdef in the package.
|
||||
static constexpr char kMemmappedPackageDefaultGraphDef[] =
|
||||
"memmapped_package://.";
|
||||
|
||||
MemmappedFileSystem();
|
||||
~MemmappedFileSystem() override = default;
|
||||
bool FileExists(const string& fname) override;
|
||||
Status NewRandomAccessFile(const string& filename,
|
||||
RandomAccessFile** result) override;
|
||||
Status NewReadOnlyMemoryRegionFromFile(
|
||||
const string& filename, ReadOnlyMemoryRegion** result) override;
|
||||
|
||||
// All these functions return Unimplemented error, the memmapped storage is
|
||||
// read only.
|
||||
Status NewWritableFile(const string& fname, WritableFile** result) override;
|
||||
Status NewAppendableFile(const string& fname, WritableFile** result) override;
|
||||
Status GetChildren(const string& dir, std::vector<string>* r) override;
|
||||
Status DeleteFile(const string& f) override;
|
||||
Status CreateDir(const string& d) override;
|
||||
Status DeleteDir(const string& d) override;
|
||||
Status GetFileSize(const string& f, uint64* s) override;
|
||||
Status RenameFile(const string& s, const string& t) override;
|
||||
|
||||
// Initializes filesystem from a file in memmapped format.
|
||||
Status InitializeFromFile(Env* env, const string& filename);
|
||||
|
||||
// Checks if the filename has a correct prefix.
|
||||
static bool IsMemmappedPackageFilename(const string& filename);
|
||||
|
||||
static bool IsWellFormedMemmappedPackageFilename(const string& filename);
|
||||
|
||||
private:
|
||||
struct FileRegion {
|
||||
FileRegion(uint64 o, uint64 l) : offset(o), length(l) {}
|
||||
|
||||
uint64 offset; // Offset from the beginning of the file.
|
||||
uint64 length; // Length of the region.
|
||||
};
|
||||
|
||||
using DirectoryType = std::unordered_map<string, FileRegion>;
|
||||
|
||||
const void* GetMemoryWithOffset(uint64 offset) const;
|
||||
|
||||
std::unique_ptr<ReadOnlyMemoryRegion> mapped_memory_;
|
||||
DirectoryType directory_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystem);
|
||||
};
|
||||
|
||||
class MemmappedEnv : public EnvWrapper {
|
||||
public:
|
||||
explicit MemmappedEnv(Env* env);
|
||||
~MemmappedEnv() override = default;
|
||||
Status GetFileSystemForFile(const string& fname,
|
||||
FileSystem** result) override;
|
||||
Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes) override;
|
||||
Status InitializeFromFile(const string& filename);
|
||||
|
||||
protected:
|
||||
std::unique_ptr<MemmappedFileSystem> memmapped_file_system_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_
|
29
tensorflow/core/util/memmapped_file_system.proto
Normal file
29
tensorflow/core/util/memmapped_file_system.proto
Normal file
@ -0,0 +1,29 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
// option cc_enable_arenas = true;
|
||||
|
||||
// A message that describes one region of memmapped file.
|
||||
message MemmappedFileSystemDirectoryElement {
|
||||
uint64 offset = 1;
|
||||
string name = 2;
|
||||
}
|
||||
|
||||
// A directory of regions in a memmapped file.
|
||||
message MemmappedFileSystemDirectory {
|
||||
repeated MemmappedFileSystemDirectoryElement element = 1;
|
||||
}
|
150
tensorflow/core/util/memmapped_file_system_test.cc
Normal file
150
tensorflow/core/util/memmapped_file_system_test.cc
Normal file
@ -0,0 +1,150 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Names of files in memmapped environment.
|
||||
constexpr char kTensor1FileName[] = "memmapped_package://t1";
|
||||
constexpr char kTensor2FileName[] = "memmapped_package://t2";
|
||||
constexpr char kProtoFileName[] = "memmapped_package://b";
|
||||
constexpr int kTestGraphDefVersion = 666;
|
||||
|
||||
Status CreateMemmappedFileSystemFile(const string& filename, bool corrupted,
|
||||
Tensor* test_tensor) {
|
||||
Env* env = Env::Default();
|
||||
MemmappedFileSystemWriter writer;
|
||||
TF_RETURN_IF_ERROR(writer.InitializeToFile(env, filename));
|
||||
|
||||
// Try to write a tensor and proto.
|
||||
test::FillFn<float>(test_tensor,
|
||||
[](int i) { return static_cast<float>(i * i); });
|
||||
|
||||
TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor1FileName));
|
||||
|
||||
// Create a proto with some fields.
|
||||
GraphDef graph_def;
|
||||
graph_def.set_version(kTestGraphDefVersion);
|
||||
TF_RETURN_IF_ERROR(writer.SaveProtobuf(graph_def, kProtoFileName));
|
||||
|
||||
// Save a tensor after the proto to check that alignment works.
|
||||
test::FillFn<float>(test_tensor,
|
||||
[](int i) { return static_cast<float>(i * i * i); });
|
||||
TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor2FileName));
|
||||
|
||||
if (!corrupted) {
|
||||
// Flush and close the file.
|
||||
TF_RETURN_IF_ERROR(writer.FlushAndClose());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST(MemmappedFileSystemTest, SimpleTest) {
|
||||
const TensorShape test_tensor_shape = {10, 200};
|
||||
Tensor test_tensor(DT_FLOAT, test_tensor_shape);
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename = io::JoinPath(dir, "memmapped_env_test");
|
||||
TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, false, &test_tensor));
|
||||
|
||||
// Check that we can memmap the created file.
|
||||
MemmappedEnv memmapped_env(Env::Default());
|
||||
TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename));
|
||||
// Try to load a proto from the file.
|
||||
GraphDef test_graph_def;
|
||||
TF_EXPECT_OK(
|
||||
ReadBinaryProto(&memmapped_env, kProtoFileName, &test_graph_def));
|
||||
EXPECT_EQ(kTestGraphDefVersion, test_graph_def.version());
|
||||
// Check that we can correctly get a tensor memory.
|
||||
ReadOnlyMemoryRegion* memory_region;
|
||||
TF_ASSERT_OK(memmapped_env.NewReadOnlyMemoryRegionFromFile(kTensor2FileName,
|
||||
&memory_region));
|
||||
std::unique_ptr<ReadOnlyMemoryRegion> mem_region_ptr(memory_region);
|
||||
// The memory region can be bigger but not less than Tensor size.
|
||||
ASSERT_GE(memory_region->length(), test_tensor.TotalBytes());
|
||||
EXPECT_EQ(test_tensor.tensor_data(),
|
||||
StringPiece(static_cast<const char*>(memory_region->data()),
|
||||
test_tensor.TotalBytes()));
|
||||
// Check that GetFileSize works.
|
||||
uint64 file_size = 0;
|
||||
TF_ASSERT_OK(memmapped_env.GetFileSize(kTensor2FileName, &file_size));
|
||||
EXPECT_EQ(test_tensor.TotalBytes(), file_size);
|
||||
|
||||
// Check that if file not found correct error message returned.
|
||||
EXPECT_EQ(
|
||||
error::NOT_FOUND,
|
||||
memmapped_env.NewReadOnlyMemoryRegionFromFile("bla-bla", &memory_region)
|
||||
.code());
|
||||
|
||||
// Check FileExists.
|
||||
EXPECT_TRUE(memmapped_env.FileExists(kTensor2FileName));
|
||||
EXPECT_FALSE(memmapped_env.FileExists("bla-bla-bla"));
|
||||
}
|
||||
|
||||
TEST(MemmappedFileSystemTest, NotInitalized) {
|
||||
MemmappedEnv memmapped_env(Env::Default());
|
||||
ReadOnlyMemoryRegion* memory_region;
|
||||
EXPECT_EQ(
|
||||
error::FAILED_PRECONDITION,
|
||||
memmapped_env
|
||||
.NewReadOnlyMemoryRegionFromFile(kTensor1FileName, &memory_region)
|
||||
.code());
|
||||
RandomAccessFile* file;
|
||||
EXPECT_EQ(error::FAILED_PRECONDITION,
|
||||
memmapped_env.NewRandomAccessFile(kProtoFileName, &file).code());
|
||||
}
|
||||
|
||||
TEST(MemmappedFileSystemTest, Corrupted) {
|
||||
// Create a corrupted file (it is not closed it properly).
|
||||
const TensorShape test_tensor_shape = {100, 200};
|
||||
Tensor test_tensor(DT_FLOAT, test_tensor_shape);
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename = io::JoinPath(dir, "memmapped_env_corrupted_test");
|
||||
TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, true, &test_tensor));
|
||||
MemmappedFileSystem memmapped_env;
|
||||
ASSERT_NE(memmapped_env.InitializeFromFile(Env::Default(), filename),
|
||||
Status::OK());
|
||||
}
|
||||
|
||||
TEST(MemmappedFileSystemTest, ProxyToDefault) {
|
||||
MemmappedEnv memmapped_env(Env::Default());
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename = io::JoinPath(dir, "test_file");
|
||||
// Check that we can create write and read ordinary file.
|
||||
WritableFile* writable_file;
|
||||
TF_ASSERT_OK(memmapped_env.NewAppendableFile(filename, &writable_file));
|
||||
std::unique_ptr<WritableFile> writable_file_ptr(writable_file);
|
||||
const string test_string = "bla-bla-bla";
|
||||
TF_ASSERT_OK(writable_file->Append(test_string));
|
||||
TF_ASSERT_OK(writable_file->Close());
|
||||
uint64 file_length = 0;
|
||||
TF_EXPECT_OK(memmapped_env.GetFileSize(filename, &file_length));
|
||||
EXPECT_EQ(test_string.length(), file_length);
|
||||
RandomAccessFile* random_access_file;
|
||||
TF_ASSERT_OK(
|
||||
memmapped_env.NewRandomAccessFile(filename, &random_access_file));
|
||||
delete random_access_file;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
136
tensorflow/core/util/memmapped_file_system_writer.cc
Normal file
136
tensorflow/core/util/memmapped_file_system_writer.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/util/memmapped_file_system_writer.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status MemmappedFileSystemWriter::InitializeToFile(Env* env,
|
||||
const string& filename) {
|
||||
WritableFile* writable_file;
|
||||
auto status = env->NewWritableFile(filename, &writable_file);
|
||||
if (status.ok()) {
|
||||
output_file_.reset(writable_file);
|
||||
output_file_offset_ = 0;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor,
|
||||
const string& element_name) {
|
||||
if (!output_file_) {
|
||||
return errors::FailedPrecondition(
|
||||
"MemmappedEnvWritter: saving tensor into not opened file");
|
||||
}
|
||||
if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
|
||||
element_name)) {
|
||||
return errors::InvalidArgument(
|
||||
"MemmappedEnvWritter: element_name is invalid: must have memmapped ",
|
||||
"package prefix ", MemmappedFileSystem::kMemmappedPackagePrefix,
|
||||
" and include [A-Za-z0-9_.]");
|
||||
}
|
||||
const auto tensor_data = tensor.tensor_data();
|
||||
if (0 == tensor_data.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"MemmappedEnvWritter: saving tensor with 0 size");
|
||||
}
|
||||
// Adds pad for correct alignment after memmapping.
|
||||
TF_RETURN_IF_ERROR(AdjustAlignment(Allocator::kAllocatorAlignment));
|
||||
AddToDirectoryElement(element_name);
|
||||
const auto result = output_file_->Append(tensor_data);
|
||||
if (result.ok()) {
|
||||
output_file_offset_ += tensor_data.size();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status MemmappedFileSystemWriter::SaveProtobuf(
|
||||
const protobuf::MessageLite& message, const string& element_name) {
|
||||
if (!output_file_) {
|
||||
return errors::FailedPrecondition(
|
||||
"MemmappedEnvWritter: saving protobuf into not opened file");
|
||||
}
|
||||
if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename(
|
||||
element_name)) {
|
||||
return errors::InvalidArgument(
|
||||
"MemmappedEnvWritter: element_name is invalid: must have memmapped "
|
||||
"package prefix ",
|
||||
MemmappedFileSystem::kMemmappedPackagePrefix,
|
||||
" and include [A-Za-z0-9_.]");
|
||||
}
|
||||
AddToDirectoryElement(element_name);
|
||||
const string encoded = message.SerializeAsString();
|
||||
const auto res = output_file_->Append(encoded);
|
||||
if (res.ok()) {
|
||||
output_file_offset_ += encoded.size();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
StringPiece EncodeUint64LittleEndian(uint64 val, char* output_buffer) {
|
||||
for (int i = 0; i < sizeof(uint64); ++i) {
|
||||
output_buffer[i] = (val >> i * 8);
|
||||
}
|
||||
return {output_buffer, sizeof(uint64)};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status MemmappedFileSystemWriter::FlushAndClose() {
|
||||
if (!output_file_) {
|
||||
return errors::FailedPrecondition(
|
||||
"MemmappedEnvWritter: flushing into not opened file");
|
||||
}
|
||||
const string dir = directory_.SerializeAsString();
|
||||
TF_RETURN_IF_ERROR(output_file_->Append(dir));
|
||||
|
||||
// Write the directory offset.
|
||||
char buffer[sizeof(uint64)];
|
||||
TF_RETURN_IF_ERROR(output_file_->Append(
|
||||
EncodeUint64LittleEndian(output_file_offset_, buffer)));
|
||||
|
||||
// Flush and close the file.
|
||||
TF_RETURN_IF_ERROR(output_file_->Flush());
|
||||
TF_RETURN_IF_ERROR(output_file_->Close());
|
||||
output_file_.reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) {
|
||||
const uint64 alignment_rest = output_file_offset_ % alignment;
|
||||
const uint64 to_write_for_alignment =
|
||||
(alignment_rest == 0) ? 0 : alignment - (output_file_offset_ % alignment);
|
||||
static constexpr uint64 kFillerBufferSize = 16;
|
||||
const char kFillerBuffer[kFillerBufferSize] = {};
|
||||
for (uint64 rest = to_write_for_alignment; rest > 0;) {
|
||||
StringPiece sp(kFillerBuffer, std::min(rest, kFillerBufferSize));
|
||||
TF_RETURN_IF_ERROR(output_file_->Append(sp));
|
||||
rest -= sp.size();
|
||||
output_file_offset_ += sp.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MemmappedFileSystemWriter::AddToDirectoryElement(const string& name) {
|
||||
MemmappedFileSystemDirectoryElement* new_directory_element =
|
||||
directory_.add_element();
|
||||
new_directory_element->set_offset(output_file_offset_);
|
||||
new_directory_element->set_name(name);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
53
tensorflow/core/util/memmapped_file_system_writer.h
Normal file
53
tensorflow/core/util/memmapped_file_system_writer.h
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_
|
||||
#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A class for saving into the memmapped format that can be read by
|
||||
// MemmappedFileSystem.
|
||||
class MemmappedFileSystemWriter {
|
||||
public:
|
||||
MemmappedFileSystemWriter() = default;
|
||||
~MemmappedFileSystemWriter() = default;
|
||||
Status InitializeToFile(Env* env, const string& filename);
|
||||
Status SaveTensor(const Tensor& tensor, const string& element_name);
|
||||
Status SaveProtobuf(const protobuf::MessageLite& message,
|
||||
const string& element_name);
|
||||
// Writes out the directory of regions and closes the output file.
|
||||
Status FlushAndClose();
|
||||
|
||||
private:
|
||||
Status AdjustAlignment(uint64 alignment);
|
||||
void AddToDirectoryElement(const string& element_name);
|
||||
MemmappedFileSystemDirectory directory_;
|
||||
// The current offset in the file, to support alignment.
|
||||
uint64 output_file_offset_ = 0;
|
||||
std::unique_ptr<WritableFile> output_file_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystemWriter);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_
|
@ -61,6 +61,7 @@ def tf_android_core_proto_sources_relative():
|
||||
"lib/core/error_codes.proto",
|
||||
"protobuf/config.proto",
|
||||
"protobuf/saver.proto",
|
||||
"util/memmapped_file_system.proto",
|
||||
"util/saved_tensor_slice.proto",
|
||||
"util/test_log.proto",
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user