From 3c280f6fa0e0fcaa3d2cee5d2d8bb7ab3e25319f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 18 Apr 2016 08:08:23 -0800 Subject: [PATCH] Added a format for saving an inference graph that can be memmapped and an utility to convert a freezed graph into this format. Change: 120128412 --- tensorflow/contrib/util/BUILD | 41 +++ .../util/convert_graphdef_memmapped_format.cc | 88 ++++++ .../convert_graphdef_memmapped_format_lib.cc | 156 ++++++++++ .../convert_graphdef_memmapped_format_lib.h | 34 +++ .../convert_graphdef_memmapped_format_test.cc | 84 ++++++ tensorflow/core/BUILD | 2 + tensorflow/core/framework/allocator.cc | 2 + tensorflow/core/framework/allocator.h | 16 +- .../core/kernels/immutable_constant_op.cc | 11 +- .../core/kernels/immutable_constant_op.h | 5 + tensorflow/core/platform/env.cc | 22 +- tensorflow/core/platform/env.h | 10 +- tensorflow/core/platform/file_system.h | 4 +- tensorflow/core/util/memmapped_file_system.cc | 281 ++++++++++++++++++ tensorflow/core/util/memmapped_file_system.h | 121 ++++++++ .../core/util/memmapped_file_system.proto | 29 ++ .../core/util/memmapped_file_system_test.cc | 150 ++++++++++ .../core/util/memmapped_file_system_writer.cc | 136 +++++++++ .../core/util/memmapped_file_system_writer.h | 53 ++++ tensorflow/tensorflow.bzl | 1 + 20 files changed, 1219 insertions(+), 27 deletions(-) create mode 100644 tensorflow/contrib/util/convert_graphdef_memmapped_format.cc create mode 100644 tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc create mode 100644 tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h create mode 100644 tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc create mode 100644 tensorflow/core/util/memmapped_file_system.cc create mode 100644 tensorflow/core/util/memmapped_file_system.h create mode 100644 tensorflow/core/util/memmapped_file_system.proto create mode 100644 tensorflow/core/util/memmapped_file_system_test.cc create mode 100644 tensorflow/core/util/memmapped_file_system_writer.cc create mode 100644 tensorflow/core/util/memmapped_file_system_writer.h diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index c0be2b9c140..80495c9b8a1 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -7,6 +7,47 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +# Convertor of a frozen graph definition into the memmapped format. +cc_library( + name = "convert_graphdef_memmapped_format_lib", + srcs = ["convert_graphdef_memmapped_format_lib.cc"], + hdrs = ["convert_graphdef_memmapped_format_lib.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels:immutable_constant_op", + ], +) + +cc_binary( + name = "convert_graphdef_memmapped_format", + srcs = ["convert_graphdef_memmapped_format.cc"], + deps = [ + ":convert_graphdef_memmapped_format_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "convert_graphdef_memmapped_format_test", + srcs = ["convert_graphdef_memmapped_format_test.cc"], + deps = [ + ":convert_graphdef_memmapped_format_lib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_binary( name = "inspect_checkpoint", srcs = ["inspect_checkpoint.cc"], diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc new file mode 100644 index 00000000000..811761efd6b --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility that converts a "frozen" inference graph (output from the +// freeze_graph utility) into a format in which large Const ops are converted to +// ImmutableConst ops which are memmapped when the graph is executed by +// TensorFlow. +// +// tensorflow/contrib/util/convert_graphdef_memmapped_format +// --in_graph=frozen.model --out_graph=memmapped.mmodel +// +// Parameters: +// in_graph - name of a file with a frozen GraphDef proto in binary format +// out_graph - name of the output file, where the graph in memmapped format will +// be saved. +// min_conversion_size_bytes - tensors with fewer than this many bytes of data +// will not be converted to ImmutableConst format, and kept in the graph. + +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int ParseFlagsAndConvertGraph(int argc, char* argv[]) { + string in_graph = ""; + string out_graph = ""; + int min_conversion_tensor_size = 10000; + const bool parse_result = ParseFlags( + &argc, argv, + {// input graph + Flag("in_graph", &in_graph), + // output graph + Flag("out_graph", &out_graph), + // constants with tensors that have less than this number elements won't + // be converted into ImmutableConst (be memmapped). + Flag("min_conversion_tensor_size", &min_conversion_tensor_size)}); + // We need to call this to set up global state for TensorFlow. + port::InitMain(argv[0], &argc, &argv); + if (!parse_result) { + LOG(ERROR) << "Error parsing command-line flags."; + return -1; + } + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1]; + return -1; + } + if (in_graph.empty()) { + LOG(ERROR) << "in_graph graph can't be empty"; + return -1; + } + if (out_graph.empty()) { + LOG(ERROR) << "out_graph graph can't be empty"; + return -1; + } + if (min_conversion_tensor_size <= 0) { + LOG(ERROR) << "min_conversion_tensor_size must be > 0"; + return -1; + } + const auto result = ConvertConstantsToImmutable(in_graph, out_graph, + min_conversion_tensor_size); + if (!result.ok()) { + LOG(ERROR) << "Conversion failed " << result.error_message(); + return -1; + } + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::ParseFlagsAndConvertGraph(argc, argv); +} diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc new file mode 100644 index 00000000000..7697a7f3d26 --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -0,0 +1,156 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" + +#include +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/kernels/immutable_constant_op.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/memmapped_file_system_writer.h" + +namespace tensorflow { +namespace { +class NodeConverter { + public: + // Converts one node. In-place updates node_def, writes the tensor in + // memmapped + // format, using writer. If the conversion has been done, convert_counter is + // increased. + Status ConvertConstantsToImmutable(NodeDef* node_def, + MemmappedFileSystemWriter* writer, + int* convert_counter, + int min_conversion_size_bytes) { + // Check the size. + const AttrValue& value = node_def->attr().at("value"); + const TensorProto& tensor_proto = value.tensor(); + + // Create copies of tensor datatype and shape, to put into the operator + // after + // the tensor is destroyed. + const DataType tensor_data_type = tensor_proto.dtype(); + const TensorShapeProto tensor_shape = tensor_proto.tensor_shape(); + + // Create Tensor from value and write it in memmapped format. + Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + if (parsed.TotalBytes() < min_conversion_size_bytes) { + return Status::OK(); + } + + const string memmapped_region_name = + MemmappedFileSystem::kMemmappedPackagePrefix + + ConvertVariableNameToUniqueRegionName(node_def->name()); + + TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name)); + + node_def->set_op("ImmutableConst"); + + // Erase all attributes and leave only attributes that can be understood by + // ImmutableConst. + auto* mutable_attr = node_def->mutable_attr(); + mutable_attr->clear(); + + { + AttrValue attr_value; + attr_value.set_type(tensor_data_type); + mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value}); + } + { + AttrValue attr_value; + *(attr_value.mutable_shape()) = tensor_shape; + mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value}); + } + { + AttrValue attr_value; + attr_value.set_s(memmapped_region_name); + mutable_attr->insert( + {ImmutableConstantOp::kMemoryRegionNameAttr, attr_value}); + } + ++*convert_counter; + return Status::OK(); + } + + private: + string ConvertVariableNameToUniqueRegionName(const string& variable_name) { + string region_name = SanitizeVariableName(variable_name); + while (!used_names_.insert(region_name).second) { + region_name += '_'; + } + return region_name; + } + + static string SanitizeVariableName(const string& variable_name) { + string result; + for (char c : variable_name) { + if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '_' || c == '.') { + result += c; + } else { + result += '_'; + } + } + return result; + } + std::unordered_set used_names_; +}; + +} // namespace + +// Loads the graph, replaces operators, and writes it out. +Status ConvertConstantsToImmutable(const string& in_graph_filename, + const string& out_graph_filename, + int min_conversion_size_bytes) { + Env* default_env = Env::Default(); + GraphDef graph_def; + const auto load_graph_status = + ReadBinaryProto(default_env, in_graph_filename, &graph_def); + if (!load_graph_status.ok()) { + return tensorflow::errors::NotFound("Failed to load graph at '", + in_graph_filename, "' : ", + load_graph_status.error_message()); + } + + NodeConverter node_converter; + + // Create output writer. + MemmappedFileSystemWriter writer; + TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename)); + + // Iterate over graph nodes, looking for Const and replacing it with + // ImmutableConst. + int convert_counter = 0; + for (int i = 0; i < graph_def.node_size(); ++i) { + const NodeDef& node = graph_def.node(i); + if (node.op() == "Const") { + // Try to convert to ImmutableConst + TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable( + graph_def.mutable_node(i), &writer, &convert_counter, + min_conversion_size_bytes)); + } + } + TF_RETURN_IF_ERROR(writer.SaveProtobuf( + graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef)); + TF_RETURN_IF_ERROR(writer.FlushAndClose()); + LOG(INFO) << "Converted " << convert_counter << " nodes"; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h new file mode 100644 index 00000000000..e6fd1bb132f --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h @@ -0,0 +1,34 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Converts a "frozen" inference graph (output from the freeze_graph utility) +// into a format in which large Const ops are converted to ImmutableConst ops +// which are memmapped when the graph is executed by TensorFlow. +Status ConvertConstantsToImmutable(const string& in_graph_filename, + const string& out_graph_filename, + int min_conversion_size_bytes); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc new file mode 100644 index 00000000000..7710fc38efe --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/memmapped_file_system.h" + +namespace tensorflow { +namespace { + +TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { + const string dir = testing::TmpDir(); + const string filename_pb = io::JoinPath(dir, "graphdef.pb"); + + // Create a simple graph and write it to filename_pb. + constexpr int kTensorWidth = 4000; + constexpr int kTensorHeight = 100; + const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); + const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth}); + + Tensor test_tensor1(DT_FLOAT, kTestTensorShape); + test::FillFn(&test_tensor1, [](int) -> float { return 2.0; }); + + Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT); + test::FillFn(&test_tensor2, [](int) -> float { return 3.0; }); + + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* node1 = ops::Const(test_tensor1, b.opts()); + Node* node2 = ops::Const(test_tensor2, b.opts()); + const string result_name = ops::MatMul(node1, node2, b.opts())->name(); + + GraphDef graph_def; + TF_ASSERT_OK(b.ToGraphDef(&graph_def)); + string graph_def_serialized; + graph_def.SerializeToString(&graph_def_serialized); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized)); + + const string filename_mmap = io::JoinPath(dir, "graphdef.mmap"); + TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000)); + + // Create and initialize MemmappedEnv from the converted file. + MemmappedEnv memmapped_env(Env::Default()); + TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap)); + + // Load the graph and run calculations. + SessionOptions session_options; + session_options.env = &memmapped_env; + std::unique_ptr session(NewSession(session_options)); + ASSERT_TRUE(session != nullptr) << "Failed to create session"; + GraphDef loaded_graph_def; + TF_ASSERT_OK(ReadBinaryProto( + &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &loaded_graph_def)); + + TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph"; + std::vector outputs; + TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs.front().flat()(0), 2.0f * 3.0f * kTensorHeight); + EXPECT_EQ(outputs.front().flat()(1), 2.0f * 3.0f * kTensorHeight); + EXPECT_EQ(outputs.front().flat()(2), 2.0f * 3.0f * kTensorHeight); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index f553c88b252..080e8514475 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -261,6 +261,8 @@ tf_cuda_library( "util/device_name_utils.h", "util/events_writer.h", "util/guarded_philox_random.h", + "util/memmapped_file_system.h", + "util/memmapped_file_system_writer.h", "util/mirror_pad_mode.h", "util/padding.h", "util/port.h", diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index f11ebd11d32..f5fc55f1e03 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -43,6 +43,8 @@ string AllocatorStats::DebugString() const { this->num_allocs, this->max_alloc_size); } +constexpr size_t Allocator::kAllocatorAlignment; + Allocator::~Allocator() {} // If true, cpu allocator collects more stats. diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 28a42165496..b3614f3fd96 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -66,6 +66,9 @@ struct AllocatorStats { // device memory. class Allocator { public: + // Align to 32 byte boundary. + static constexpr size_t kAllocatorAlignment = 32; + virtual ~Allocator(); // Return a string identifying this allocator @@ -112,8 +115,8 @@ class Allocator { return NULL; } - void* p = AllocateRaw(32 /* align to 32 byte boundary */, - sizeof(T) * num_elements, allocation_attr); + void* p = AllocateRaw(kAllocatorAlignment, sizeof(T) * num_elements, + allocation_attr); T* typed_p = reinterpret_cast(p); if (typed_p) RunCtor(typed_p, num_elements); return typed_p; @@ -192,11 +195,10 @@ class Allocator { // without running their default ctors and dtors. template struct is_simple { - static const bool value = std::is_trivial::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - is_quantized::value; + static constexpr bool value = + std::is_trivial::value || std::is_same::value || + std::is_same::value || + std::is_same::value || is_quantized::value; }; // Fills in 'stats' with statistics collected by this allocator. diff --git a/tensorflow/core/kernels/immutable_constant_op.cc b/tensorflow/core/kernels/immutable_constant_op.cc index 3b8c1b1ca02..22bd88793ed 100644 --- a/tensorflow/core/kernels/immutable_constant_op.cc +++ b/tensorflow/core/kernels/immutable_constant_op.cc @@ -20,11 +20,12 @@ namespace tensorflow { ImmutableConstantOp::ImmutableConstantOp(OpKernelConstruction* context) : OpKernel(context) { ::tensorflow::DataType dtype; - OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype)); + OP_REQUIRES_OK(context, context->GetAttr(kDTypeAttr, &dtype)); ::tensorflow::TensorShape shape; - OP_REQUIRES_OK(context, context->GetAttr("shape", &shape)); + OP_REQUIRES_OK(context, context->GetAttr(kShapeAttr, &shape)); string region_name; - OP_REQUIRES_OK(context, context->GetAttr("memory_region_name", ®ion_name)); + OP_REQUIRES_OK(context, + context->GetAttr(kMemoryRegionNameAttr, ®ion_name)); OP_REQUIRES_OK(context, allocator_.InitWithMemoryRegion(region_name, context->env())); tensor_ = ::tensorflow::Tensor(&allocator_, dtype, shape); @@ -90,6 +91,10 @@ void ImmutableConstantOp::ReadOnlyMemoryRegionAllocator::DeallocateRaw( } } +constexpr char ImmutableConstantOp::kDTypeAttr[]; +constexpr char ImmutableConstantOp::kShapeAttr[]; +constexpr char ImmutableConstantOp::kMemoryRegionNameAttr[]; + REGISTER_KERNEL_BUILDER(Name("ImmutableConst").Device(DEVICE_CPU), ImmutableConstantOp); } // namespace tensorflow diff --git a/tensorflow/core/kernels/immutable_constant_op.h b/tensorflow/core/kernels/immutable_constant_op.h index a731fa4b972..ecfc1a027a4 100644 --- a/tensorflow/core/kernels/immutable_constant_op.h +++ b/tensorflow/core/kernels/immutable_constant_op.h @@ -32,6 +32,11 @@ class ImmutableConstantOp : public OpKernel { bool IsExpensive() override { return false; } ~ImmutableConstantOp() override; + // Names of attributes that are used by this op + static constexpr char kDTypeAttr[] = "dtype"; + static constexpr char kShapeAttr[] = "shape"; + static constexpr char kMemoryRegionNameAttr[] = "memory_region_name"; + private: class ReadOnlyMemoryRegionAllocator : public ::tensorflow::Allocator { public: diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index 714b4511f89..dc84fb9ac33 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -23,7 +23,7 @@ namespace tensorflow { class FileSystemRegistryImpl : public FileSystemRegistry { public: - void Register(const string& scheme, Factory factory) override; + Status Register(const string& scheme, Factory factory) override; FileSystem* Lookup(const string& scheme) override; Status GetRegisteredFileSystemSchemes(std::vector* schemes) override; @@ -33,13 +33,15 @@ class FileSystemRegistryImpl : public FileSystemRegistry { GUARDED_BY(mu_); }; -void FileSystemRegistryImpl::Register(const string& scheme, - FileSystemRegistry::Factory factory) { +Status FileSystemRegistryImpl::Register(const string& scheme, + FileSystemRegistry::Factory factory) { mutex_lock lock(mu_); - QCHECK( - registry_.emplace(string(scheme), std::unique_ptr(factory())) - .second) - << "File factory for " << scheme << " already registered"; + if (!registry_.emplace(string(scheme), std::unique_ptr(factory())) + .second) { + return errors::AlreadyExists("File factory for ", scheme, + " already registered"); + } + return Status::OK(); } FileSystem* FileSystemRegistryImpl::Lookup(const string& scheme) { @@ -77,9 +79,9 @@ Status Env::GetRegisteredFileSystemSchemes(std::vector* schemes) { return file_system_registry_->GetRegisteredFileSystemSchemes(schemes); } -void Env::RegisterFileSystem(const string& scheme, - FileSystemRegistry::Factory factory) { - file_system_registry_->Register(scheme, factory); +Status Env::RegisterFileSystem(const string& scheme, + FileSystemRegistry::Factory factory) { + return file_system_registry_->Register(scheme, factory); } Status Env::NewRandomAccessFile(const string& fname, diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 1abe5cd2c0b..dd1beb32465 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -68,8 +68,8 @@ class Env { virtual Status GetRegisteredFileSystemSchemes(std::vector* schemes); // \brief Register a file system for a scheme. - virtual void RegisterFileSystem(const string& scheme, - FileSystemRegistry::Factory factory); + virtual Status RegisterFileSystem(const string& scheme, + FileSystemRegistry::Factory factory); /// \brief Creates a brand new random access read-only file with the /// specified name. @@ -236,9 +236,9 @@ class EnvWrapper : public Env { return target_->GetRegisteredFileSystemSchemes(schemes); } - void RegisterFileSystem(const string& scheme, - FileSystemRegistry::Factory factory) override { - target_->RegisterFileSystem(scheme, factory); + Status RegisterFileSystem(const string& scheme, + FileSystemRegistry::Factory factory) override { + return target_->RegisterFileSystem(scheme, factory); } uint64 NowMicros() override { return target_->NowMicros(); } diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index fc64ae17b82..8678ef4d3af 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -34,7 +34,7 @@ class RandomAccessFile; class ReadOnlyMemoryRegion; class WritableFile; -/// An generic interface for accessing a file system. +/// A generic interface for accessing a file system. class FileSystem { public: FileSystem() {} @@ -202,7 +202,7 @@ class FileSystemRegistry { typedef std::function Factory; virtual ~FileSystemRegistry(); - virtual void Register(const string& scheme, Factory factory) = 0; + virtual Status Register(const string& scheme, Factory factory) = 0; virtual FileSystem* Lookup(const string& scheme) = 0; virtual Status GetRegisteredFileSystemSchemes( std::vector* schemes) = 0; diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc new file mode 100644 index 00000000000..b0537c8038a --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -0,0 +1,281 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/util/memmapped_file_system.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/memmapped_file_system.pb.h" + +namespace tensorflow { + +namespace { + +uint64 DecodeUint64LittleEndian(const uint8* buffer) { + uint64 result = 0; + for (int i = 0; i < static_cast(sizeof(uint64)); ++i) { + result |= buffer[i] << (8 * i); + } + return result; +} + +} // namespace + +namespace { + +class ReadOnlyMemoryRegionFromMemmapped : public ReadOnlyMemoryRegion { + public: + ReadOnlyMemoryRegionFromMemmapped(const void* data, uint64 length) + : data_(data), length_(length) {} + ~ReadOnlyMemoryRegionFromMemmapped() override = default; + const void* data() override { return data_; } + uint64 length() override { return length_; } + + private: + const void* const data_; + const uint64 length_; + // intentionally copyable +}; + +class RandomAccessFileFromMemmapped : public RandomAccessFile { + public: + RandomAccessFileFromMemmapped(const void* data, uint64 length) + : data_(data), length_(length) {} + + ~RandomAccessFileFromMemmapped() override = default; + + Status Read(uint64 offset, size_t to_read, StringPiece* result, + char* scratch) const override { + if (offset >= length_) { + result->set(scratch, 0); + return Status(error::OUT_OF_RANGE, "Read after file end"); + } + const uint64 region_left = + std::min(length_ - offset, static_cast(to_read)); + result->set(reinterpret_cast(data_) + offset, region_left); + return (region_left == to_read) + ? Status::OK() + : Status(error::OUT_OF_RANGE, "Read less bytes than requested"); + } + + private: + const void* const data_; + const uint64 length_; + // intentionally copyable +}; + +} // namespace + +MemmappedFileSystem::MemmappedFileSystem() {} + +bool MemmappedFileSystem::FileExists(const string& fname) { + if (!mapped_memory_) { + return false; + } + const auto dir_element = directory_.find(fname); + return dir_element != directory_.end(); +} + +Status MemmappedFileSystem::NewRandomAccessFile(const string& filename, + RandomAccessFile** result) { + if (!mapped_memory_) { + return errors::FailedPrecondition("MemmappedEnv is not initialized"); + } + const auto dir_element = directory_.find(filename); + if (dir_element == directory_.end()) { + return errors::NotFound("Region ", filename, " is not found"); + } + *result = new RandomAccessFileFromMemmapped( + GetMemoryWithOffset(dir_element->second.offset), + dir_element->second.length); + return Status::OK(); +} + +Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile( + const string& filename, ReadOnlyMemoryRegion** result) { + if (!mapped_memory_) { + return errors::FailedPrecondition("MemmappedEnv is not initialized"); + } + const auto dir_element = directory_.find(filename); + if (dir_element == directory_.end()) { + return errors::NotFound("Region ", filename, " is not found"); + } + *result = new ReadOnlyMemoryRegionFromMemmapped( + GetMemoryWithOffset(dir_element->second.offset), + dir_element->second.length); + return Status::OK(); +} + +Status MemmappedFileSystem::GetFileSize(const string& filename, uint64* size) { + if (!mapped_memory_) { + return errors::FailedPrecondition("MemmappedEnv is not initialized"); + } + const auto dir_element = directory_.find(filename); + if (dir_element == directory_.end()) { + return errors::NotFound("Region ", filename, " is not found"); + } + *size = dir_element->second.length; + return Status::OK(); +} + +Status MemmappedFileSystem::NewWritableFile(const string& filename, + WritableFile** wf) { + return errors::Unimplemented("memmapped format doesn't support writing"); +} + +Status MemmappedFileSystem::NewAppendableFile(const string& filename, + WritableFile** result) { + return errors::Unimplemented("memmapped format doesn't support writing"); +} + +Status MemmappedFileSystem::GetChildren(const string& filename, + std::vector* strings) { + return errors::Unimplemented("memmapped format doesn't support GetChildren"); +} + +Status MemmappedFileSystem::DeleteFile(const string& filename) { + return errors::Unimplemented("memmapped format doesn't support DeleteFile"); +} + +Status MemmappedFileSystem::CreateDir(const string& dirname) { + return errors::Unimplemented("memmapped format doesn't support CreateDir"); +} + +Status MemmappedFileSystem::DeleteDir(const string& dirname) { + return errors::Unimplemented("memmapped format doesn't support DeleteDir"); +} + +Status MemmappedFileSystem::RenameFile(const string& filename_from, + const string& filename_to) { + return errors::Unimplemented("memmapped format doesn't support RenameFile"); +} + +const void* MemmappedFileSystem::GetMemoryWithOffset(uint64 offset) const { + return reinterpret_cast(mapped_memory_->data()) + offset; +} + +constexpr char MemmappedFileSystem::kMemmappedPackagePrefix[]; +constexpr char MemmappedFileSystem::kMemmappedPackageDefaultGraphDef[]; + +Status MemmappedFileSystem::InitializeFromFile(Env* env, + const string& filename) { + ReadOnlyMemoryRegion* region; + TF_RETURN_IF_ERROR(env->NewReadOnlyMemoryRegionFromFile(filename, ®ion)); + mapped_memory_.reset(region); + directory_.clear(); + if (mapped_memory_->length() <= sizeof(uint64)) { + return errors::DataLoss("Corrupted memmapped model file: ", filename, + " Invalid package size"); + } + const auto memory_start = + reinterpret_cast(mapped_memory_->data()); + const uint64 directory_offset = DecodeUint64LittleEndian( + memory_start + mapped_memory_->length() - sizeof(uint64)); + if (directory_offset > mapped_memory_->length() - sizeof(uint64)) { + return errors::DataLoss("Corrupted memmapped model file: ", filename, + " Invalid directory offset"); + } + MemmappedFileSystemDirectory proto_directory; + if (!ParseProtoUnlimited( + &proto_directory, memory_start + directory_offset, + mapped_memory_->length() - directory_offset - sizeof(uint64))) { + return errors::DataLoss("Corrupted memmapped model file: ", filename, + " Can't parse its internal directory"); + } + + // Iterating in reverse order to get lengths of elements; + uint64 prev_element_offset = directory_offset; + for (auto element_iter = proto_directory.element().rbegin(); + element_iter != proto_directory.element().rend(); ++element_iter) { + // Check that the element offset is in the right range. + if (element_iter->offset() >= prev_element_offset) { + return errors::DataLoss("Corrupted memmapped model file: ", filename, + " Invalid offset of internal component"); + } + if (!directory_ + .insert(std::make_pair( + element_iter->name(), + FileRegion(element_iter->offset(), + prev_element_offset - element_iter->offset()))) + .second) { + return errors::DataLoss("Corrupted memmapped model file: ", filename, + " Duplicate name of internal component ", + element_iter->name()); + } + prev_element_offset = element_iter->offset(); + } + return Status::OK(); +} + +bool MemmappedFileSystem::IsMemmappedPackageFilename(const string& filename) { + return StringPiece(filename).starts_with(kMemmappedPackagePrefix); +} + +namespace { +bool IsValidRegionChar(char c) { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '_' || c == '.'; +} +} // namespace + +bool MemmappedFileSystem::IsWellFormedMemmappedPackageFilename( + const string& filename) { + if (!IsMemmappedPackageFilename(filename)) { + return false; + } + for (char c : + filename.substr(strlen(kMemmappedPackagePrefix), + filename.length() - strlen(kMemmappedPackagePrefix))) { + if (!IsValidRegionChar(c)) { + return false; + } + } + return true; +} + +MemmappedEnv::MemmappedEnv(Env* env) : EnvWrapper(env) {} + +Status MemmappedEnv::GetFileSystemForFile(const string& fname, + FileSystem** result) { + if (MemmappedFileSystem::IsMemmappedPackageFilename(fname)) { + if (!memmapped_file_system_) { + return errors::FailedPrecondition( + "MemmappedEnv is not initialized from a file."); + } + *result = memmapped_file_system_.get(); + return Status::OK(); + } + return EnvWrapper::GetFileSystemForFile(fname, result); +} + +Status MemmappedEnv::GetRegisteredFileSystemSchemes( + std::vector* schemes) { + const auto status = EnvWrapper::GetRegisteredFileSystemSchemes(schemes); + if (status.ok()) { + schemes->emplace_back(MemmappedFileSystem::kMemmappedPackagePrefix); + } + return status; +} + +Status MemmappedEnv::InitializeFromFile(const string& package_filename) { + std::unique_ptr file_system_ptr(new MemmappedFileSystem); + const auto status = + file_system_ptr->InitializeFromFile(target(), package_filename); + if (status.ok()) { + memmapped_file_system_ = std::move(file_system_ptr); + } + return status; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/memmapped_file_system.h b/tensorflow/core/util/memmapped_file_system.h new file mode 100644 index 00000000000..29356fbd959 --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system.h @@ -0,0 +1,121 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +// A file system that uses a graph saved in memmapped format by +// MemmappedEnvWriter as a file system. +// +// The format supports saved tensors and protos. Tensors are saved at aligned +// offsets. +// +// Format specification: +// - last 8 bytes of a package is encoded offset to the directory. The encoding +// is always little endian, independently from the platform, done by functions +// EncodeUint64LittleEndian/DecodeUint64LittleEndian +// - the directory starts from the encoded offset and is saved proto +// MemmappedFileSystemDirectory with names and offsets to the regions. +// - at the offsets in the directory the file regions are stored. Tensor regions +// are aligned such way that when the package mapped to RAM they have the right +// offset to be used by ImmutableConst operator. +// +// Region naming: +// Region naming is up to the application, all of them starts from +// kMemmappedPackagePrefix. The default graph usually has name +// kMemmappedPackageDefaultGraphDef; for more details see the conversion +// utility +// third_party/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc +// +// A "frozen" GraphDef can be converted into this format using +// tensorflow/contrib/util/convert_graphdef_memmapped_format +class MemmappedFileSystem : public FileSystem { + public: + // Memmapped regions use this prefix to distinguish from + // the filesystem. + static constexpr char kMemmappedPackagePrefix[] = "memmapped_package://"; + // The default graphdef in the package. + static constexpr char kMemmappedPackageDefaultGraphDef[] = + "memmapped_package://."; + + MemmappedFileSystem(); + ~MemmappedFileSystem() override = default; + bool FileExists(const string& fname) override; + Status NewRandomAccessFile(const string& filename, + RandomAccessFile** result) override; + Status NewReadOnlyMemoryRegionFromFile( + const string& filename, ReadOnlyMemoryRegion** result) override; + + // All these functions return Unimplemented error, the memmapped storage is + // read only. + Status NewWritableFile(const string& fname, WritableFile** result) override; + Status NewAppendableFile(const string& fname, WritableFile** result) override; + Status GetChildren(const string& dir, std::vector* r) override; + Status DeleteFile(const string& f) override; + Status CreateDir(const string& d) override; + Status DeleteDir(const string& d) override; + Status GetFileSize(const string& f, uint64* s) override; + Status RenameFile(const string& s, const string& t) override; + + // Initializes filesystem from a file in memmapped format. + Status InitializeFromFile(Env* env, const string& filename); + + // Checks if the filename has a correct prefix. + static bool IsMemmappedPackageFilename(const string& filename); + + static bool IsWellFormedMemmappedPackageFilename(const string& filename); + + private: + struct FileRegion { + FileRegion(uint64 o, uint64 l) : offset(o), length(l) {} + + uint64 offset; // Offset from the beginning of the file. + uint64 length; // Length of the region. + }; + + using DirectoryType = std::unordered_map; + + const void* GetMemoryWithOffset(uint64 offset) const; + + std::unique_ptr mapped_memory_; + DirectoryType directory_; + + TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystem); +}; + +class MemmappedEnv : public EnvWrapper { + public: + explicit MemmappedEnv(Env* env); + ~MemmappedEnv() override = default; + Status GetFileSystemForFile(const string& fname, + FileSystem** result) override; + Status GetRegisteredFileSystemSchemes(std::vector* schemes) override; + Status InitializeFromFile(const string& filename); + + protected: + std::unique_ptr memmapped_file_system_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_H_ diff --git a/tensorflow/core/util/memmapped_file_system.proto b/tensorflow/core/util/memmapped_file_system.proto new file mode 100644 index 00000000000..54df54e60b4 --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system.proto @@ -0,0 +1,29 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +syntax = "proto3"; + +package tensorflow; +// option cc_enable_arenas = true; + +// A message that describes one region of memmapped file. +message MemmappedFileSystemDirectoryElement { + uint64 offset = 1; + string name = 2; +} + +// A directory of regions in a memmapped file. +message MemmappedFileSystemDirectory { + repeated MemmappedFileSystemDirectoryElement element = 1; +} diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc new file mode 100644 index 00000000000..67a06824cd4 --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -0,0 +1,150 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/util/memmapped_file_system.h" + +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/memmapped_file_system_writer.h" + +namespace tensorflow { + +namespace { + +// Names of files in memmapped environment. +constexpr char kTensor1FileName[] = "memmapped_package://t1"; +constexpr char kTensor2FileName[] = "memmapped_package://t2"; +constexpr char kProtoFileName[] = "memmapped_package://b"; +constexpr int kTestGraphDefVersion = 666; + +Status CreateMemmappedFileSystemFile(const string& filename, bool corrupted, + Tensor* test_tensor) { + Env* env = Env::Default(); + MemmappedFileSystemWriter writer; + TF_RETURN_IF_ERROR(writer.InitializeToFile(env, filename)); + + // Try to write a tensor and proto. + test::FillFn(test_tensor, + [](int i) { return static_cast(i * i); }); + + TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor1FileName)); + + // Create a proto with some fields. + GraphDef graph_def; + graph_def.set_version(kTestGraphDefVersion); + TF_RETURN_IF_ERROR(writer.SaveProtobuf(graph_def, kProtoFileName)); + + // Save a tensor after the proto to check that alignment works. + test::FillFn(test_tensor, + [](int i) { return static_cast(i * i * i); }); + TF_RETURN_IF_ERROR(writer.SaveTensor(*test_tensor, kTensor2FileName)); + + if (!corrupted) { + // Flush and close the file. + TF_RETURN_IF_ERROR(writer.FlushAndClose()); + } + return Status::OK(); +} + +TEST(MemmappedFileSystemTest, SimpleTest) { + const TensorShape test_tensor_shape = {10, 200}; + Tensor test_tensor(DT_FLOAT, test_tensor_shape); + const string dir = testing::TmpDir(); + const string filename = io::JoinPath(dir, "memmapped_env_test"); + TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, false, &test_tensor)); + + // Check that we can memmap the created file. + MemmappedEnv memmapped_env(Env::Default()); + TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename)); + // Try to load a proto from the file. + GraphDef test_graph_def; + TF_EXPECT_OK( + ReadBinaryProto(&memmapped_env, kProtoFileName, &test_graph_def)); + EXPECT_EQ(kTestGraphDefVersion, test_graph_def.version()); + // Check that we can correctly get a tensor memory. + ReadOnlyMemoryRegion* memory_region; + TF_ASSERT_OK(memmapped_env.NewReadOnlyMemoryRegionFromFile(kTensor2FileName, + &memory_region)); + std::unique_ptr mem_region_ptr(memory_region); + // The memory region can be bigger but not less than Tensor size. + ASSERT_GE(memory_region->length(), test_tensor.TotalBytes()); + EXPECT_EQ(test_tensor.tensor_data(), + StringPiece(static_cast(memory_region->data()), + test_tensor.TotalBytes())); + // Check that GetFileSize works. + uint64 file_size = 0; + TF_ASSERT_OK(memmapped_env.GetFileSize(kTensor2FileName, &file_size)); + EXPECT_EQ(test_tensor.TotalBytes(), file_size); + + // Check that if file not found correct error message returned. + EXPECT_EQ( + error::NOT_FOUND, + memmapped_env.NewReadOnlyMemoryRegionFromFile("bla-bla", &memory_region) + .code()); + + // Check FileExists. + EXPECT_TRUE(memmapped_env.FileExists(kTensor2FileName)); + EXPECT_FALSE(memmapped_env.FileExists("bla-bla-bla")); +} + +TEST(MemmappedFileSystemTest, NotInitalized) { + MemmappedEnv memmapped_env(Env::Default()); + ReadOnlyMemoryRegion* memory_region; + EXPECT_EQ( + error::FAILED_PRECONDITION, + memmapped_env + .NewReadOnlyMemoryRegionFromFile(kTensor1FileName, &memory_region) + .code()); + RandomAccessFile* file; + EXPECT_EQ(error::FAILED_PRECONDITION, + memmapped_env.NewRandomAccessFile(kProtoFileName, &file).code()); +} + +TEST(MemmappedFileSystemTest, Corrupted) { + // Create a corrupted file (it is not closed it properly). + const TensorShape test_tensor_shape = {100, 200}; + Tensor test_tensor(DT_FLOAT, test_tensor_shape); + const string dir = testing::TmpDir(); + const string filename = io::JoinPath(dir, "memmapped_env_corrupted_test"); + TF_ASSERT_OK(CreateMemmappedFileSystemFile(filename, true, &test_tensor)); + MemmappedFileSystem memmapped_env; + ASSERT_NE(memmapped_env.InitializeFromFile(Env::Default(), filename), + Status::OK()); +} + +TEST(MemmappedFileSystemTest, ProxyToDefault) { + MemmappedEnv memmapped_env(Env::Default()); + const string dir = testing::TmpDir(); + const string filename = io::JoinPath(dir, "test_file"); + // Check that we can create write and read ordinary file. + WritableFile* writable_file; + TF_ASSERT_OK(memmapped_env.NewAppendableFile(filename, &writable_file)); + std::unique_ptr writable_file_ptr(writable_file); + const string test_string = "bla-bla-bla"; + TF_ASSERT_OK(writable_file->Append(test_string)); + TF_ASSERT_OK(writable_file->Close()); + uint64 file_length = 0; + TF_EXPECT_OK(memmapped_env.GetFileSize(filename, &file_length)); + EXPECT_EQ(test_string.length(), file_length); + RandomAccessFile* random_access_file; + TF_ASSERT_OK( + memmapped_env.NewRandomAccessFile(filename, &random_access_file)); + delete random_access_file; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/memmapped_file_system_writer.cc b/tensorflow/core/util/memmapped_file_system_writer.cc new file mode 100644 index 00000000000..294f997aab8 --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system_writer.cc @@ -0,0 +1,136 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/util/memmapped_file_system_writer.h" + +#include + +namespace tensorflow { + +Status MemmappedFileSystemWriter::InitializeToFile(Env* env, + const string& filename) { + WritableFile* writable_file; + auto status = env->NewWritableFile(filename, &writable_file); + if (status.ok()) { + output_file_.reset(writable_file); + output_file_offset_ = 0; + } + return status; +} + +Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor, + const string& element_name) { + if (!output_file_) { + return errors::FailedPrecondition( + "MemmappedEnvWritter: saving tensor into not opened file"); + } + if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename( + element_name)) { + return errors::InvalidArgument( + "MemmappedEnvWritter: element_name is invalid: must have memmapped ", + "package prefix ", MemmappedFileSystem::kMemmappedPackagePrefix, + " and include [A-Za-z0-9_.]"); + } + const auto tensor_data = tensor.tensor_data(); + if (0 == tensor_data.size()) { + return errors::InvalidArgument( + "MemmappedEnvWritter: saving tensor with 0 size"); + } + // Adds pad for correct alignment after memmapping. + TF_RETURN_IF_ERROR(AdjustAlignment(Allocator::kAllocatorAlignment)); + AddToDirectoryElement(element_name); + const auto result = output_file_->Append(tensor_data); + if (result.ok()) { + output_file_offset_ += tensor_data.size(); + } + return result; +} + +Status MemmappedFileSystemWriter::SaveProtobuf( + const protobuf::MessageLite& message, const string& element_name) { + if (!output_file_) { + return errors::FailedPrecondition( + "MemmappedEnvWritter: saving protobuf into not opened file"); + } + if (!MemmappedFileSystem::IsWellFormedMemmappedPackageFilename( + element_name)) { + return errors::InvalidArgument( + "MemmappedEnvWritter: element_name is invalid: must have memmapped " + "package prefix ", + MemmappedFileSystem::kMemmappedPackagePrefix, + " and include [A-Za-z0-9_.]"); + } + AddToDirectoryElement(element_name); + const string encoded = message.SerializeAsString(); + const auto res = output_file_->Append(encoded); + if (res.ok()) { + output_file_offset_ += encoded.size(); + } + return res; +} + +namespace { + +StringPiece EncodeUint64LittleEndian(uint64 val, char* output_buffer) { + for (int i = 0; i < sizeof(uint64); ++i) { + output_buffer[i] = (val >> i * 8); + } + return {output_buffer, sizeof(uint64)}; +} + +} // namespace + +Status MemmappedFileSystemWriter::FlushAndClose() { + if (!output_file_) { + return errors::FailedPrecondition( + "MemmappedEnvWritter: flushing into not opened file"); + } + const string dir = directory_.SerializeAsString(); + TF_RETURN_IF_ERROR(output_file_->Append(dir)); + + // Write the directory offset. + char buffer[sizeof(uint64)]; + TF_RETURN_IF_ERROR(output_file_->Append( + EncodeUint64LittleEndian(output_file_offset_, buffer))); + + // Flush and close the file. + TF_RETURN_IF_ERROR(output_file_->Flush()); + TF_RETURN_IF_ERROR(output_file_->Close()); + output_file_.reset(); + return Status::OK(); +} + +Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) { + const uint64 alignment_rest = output_file_offset_ % alignment; + const uint64 to_write_for_alignment = + (alignment_rest == 0) ? 0 : alignment - (output_file_offset_ % alignment); + static constexpr uint64 kFillerBufferSize = 16; + const char kFillerBuffer[kFillerBufferSize] = {}; + for (uint64 rest = to_write_for_alignment; rest > 0;) { + StringPiece sp(kFillerBuffer, std::min(rest, kFillerBufferSize)); + TF_RETURN_IF_ERROR(output_file_->Append(sp)); + rest -= sp.size(); + output_file_offset_ += sp.size(); + } + return Status::OK(); +} + +void MemmappedFileSystemWriter::AddToDirectoryElement(const string& name) { + MemmappedFileSystemDirectoryElement* new_directory_element = + directory_.add_element(); + new_directory_element->set_offset(output_file_offset_); + new_directory_element->set_name(name); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/memmapped_file_system_writer.h b/tensorflow/core/util/memmapped_file_system_writer.h new file mode 100644 index 00000000000..47461eaebee --- /dev/null +++ b/tensorflow/core/util/memmapped_file_system_writer.h @@ -0,0 +1,53 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_ +#define TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/util/memmapped_file_system.h" +#include "tensorflow/core/util/memmapped_file_system.pb.h" + +namespace tensorflow { + +// A class for saving into the memmapped format that can be read by +// MemmappedFileSystem. +class MemmappedFileSystemWriter { + public: + MemmappedFileSystemWriter() = default; + ~MemmappedFileSystemWriter() = default; + Status InitializeToFile(Env* env, const string& filename); + Status SaveTensor(const Tensor& tensor, const string& element_name); + Status SaveProtobuf(const protobuf::MessageLite& message, + const string& element_name); + // Writes out the directory of regions and closes the output file. + Status FlushAndClose(); + + private: + Status AdjustAlignment(uint64 alignment); + void AddToDirectoryElement(const string& element_name); + MemmappedFileSystemDirectory directory_; + // The current offset in the file, to support alignment. + uint64 output_file_offset_ = 0; + std::unique_ptr output_file_; + TF_DISALLOW_COPY_AND_ASSIGN(MemmappedFileSystemWriter); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_MEMMAPPED_FILE_SYSTEM_WRITER_H_ diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 0512ef835e9..011eef40d93 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -61,6 +61,7 @@ def tf_android_core_proto_sources_relative(): "lib/core/error_codes.proto", "protobuf/config.proto", "protobuf/saver.proto", + "util/memmapped_file_system.proto", "util/saved_tensor_slice.proto", "util/test_log.proto", ]