diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 874f1f2092c..59cc6d68732 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -86,6 +86,7 @@ exports_files( "mutex.h", "net.h", "numa.h", + "ram_file_system.h", "resource_loader.h", "resource.h", "snappy.h", @@ -382,6 +383,15 @@ cc_library( ], ) +py_test( + name = "ram_file_system_test", + srcs = ["ram_file_system_test.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + cc_library( name = "numbers", srcs = ["numbers.cc"], @@ -1293,6 +1303,7 @@ filegroup( "profile_utils/cpu_utils.h", "profile_utils/i_cpu_utils_helper.h", "protobuf.h", + "ram_file_system.h", "random.h", "resource.h", "stacktrace.h", @@ -1507,6 +1518,7 @@ filegroup( "protobuf.cc", "protobuf.h", "protobuf_util.cc", + "ram_file_system.h", "raw_coding.h", "refcount.h", "resource.h", diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index db714938a45..49318fd0811 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -87,6 +87,7 @@ cc_library( "//tensorflow/core/platform:env.h", "//tensorflow/core/platform:file_system.h", "//tensorflow/core/platform:file_system_helper.h", + "//tensorflow/core/platform:ram_file_system.h", "//tensorflow/core/platform:threadpool.h", ], tags = [ diff --git a/tensorflow/core/platform/default/env.cc b/tensorflow/core/platform/default/env.cc index 5f7822f6583..e33155a4414 100644 --- a/tensorflow/core/platform/default/env.cc +++ b/tensorflow/core/platform/default/env.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/load_library.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/ram_file_system.h" #include "tensorflow/core/protobuf/error_codes.pb.h" namespace tensorflow { @@ -214,6 +215,8 @@ class PosixEnv : public Env { #if defined(PLATFORM_POSIX) || defined(__APPLE__) || defined(__ANDROID__) REGISTER_FILE_SYSTEM("", PosixFileSystem); REGISTER_FILE_SYSTEM("file", LocalPosixFileSystem); +REGISTER_FILE_SYSTEM("ram", RamFileSystem); + Env* Env::Default() { static Env* default_env = new PosixEnv; return default_env; diff --git a/tensorflow/core/platform/ram_file_system.h b/tensorflow/core/platform/ram_file_system.h new file mode 100644 index 00000000000..abd673b455c --- /dev/null +++ b/tensorflow/core/platform/ram_file_system.h @@ -0,0 +1,232 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ + +// Implementation of an in-memory TF filesystem for simple prototyping (e.g. +// via Colab). The TPU TF server does not have local filesystem access, which +// makes it difficult to provide Colab tutorials: users must have GCS access +// and sign-in in order to try out an example. +// +// Files are implemented on top of std::string. Directories, as with GCS or S3, +// are implicit based on the existence of child files. Multiple files may +// reference a single FS location, though no thread-safety guarantees are +// provided. + +#include + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +#ifdef PLATFORM_WINDOWS +#undef DeleteFile +#undef CopyFile +#undef TranslateName +#endif + +namespace tensorflow { + +class RamRandomAccessFile : public RandomAccessFile, public WritableFile { + public: + RamRandomAccessFile(std::string name, std::shared_ptr cord) + : name_(name), data_(cord) {} + ~RamRandomAccessFile() override {} + + Status Name(StringPiece* result) const override { + *result = name_; + return Status::OK(); + } + + Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override { + if (offset >= data_->size()) { + return errors::OutOfRange(""); + } + + uint64 left = std::min(static_cast(n), data_->size() - offset); + auto start = data_->begin() + offset; + auto end = data_->begin() + offset + left; + + std::copy(start, end, scratch); + *result = StringPiece(scratch, left); + + // In case of a partial read, we must still fill `result`, but also return + // OutOfRange. + if (left < n) { + return errors::OutOfRange(""); + } + return Status::OK(); + } + + Status Append(StringPiece data) override { + data_->append(data.data(), data.size()); + return Status::OK(); + } + + Status Close() override { return Status::OK(); } + Status Flush() override { return Status::OK(); } + Status Sync() override { return Status::OK(); } + + Status Tell(int64* position) override { + *position = -1; + return errors::Unimplemented("This filesystem does not support Tell()"); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RamRandomAccessFile); + string name_; + std::shared_ptr data_; +}; + +class RamFileSystem : public FileSystem { + public: + Status NewRandomAccessFile( + const string& fname, std::unique_ptr* result) override { + mutex_lock m(mu_); + if (fs_.find(fname) == fs_.end()) { + return errors::NotFound(""); + } + *result = std::unique_ptr( + new RamRandomAccessFile(fname, fs_[fname])); + return Status::OK(); + } + + Status NewWritableFile(const string& fname, + std::unique_ptr* result) override { + mutex_lock m(mu_); + if (fs_.find(fname) == fs_.end()) { + fs_[fname] = std::make_shared(); + } + *result = std::unique_ptr( + new RamRandomAccessFile(fname, fs_[fname])); + return Status::OK(); + } + Status NewAppendableFile(const string& fname, + std::unique_ptr* result) override { + mutex_lock m(mu_); + if (fs_.find(fname) == fs_.end()) { + fs_[fname] = std::make_shared(); + } + *result = std::unique_ptr( + new RamRandomAccessFile(fname, fs_[fname])); + return Status::OK(); + } + + Status NewReadOnlyMemoryRegionFromFile( + const string& fname, + std::unique_ptr* result) override { + return errors::Unimplemented(""); + } + + Status FileExists(const string& fname) override { + FileStatistics stat; + return Stat(fname, &stat); + } + + Status GetChildren(const string& dir, std::vector* result) override { + mutex_lock m(mu_); + auto it = fs_.lower_bound(dir); + while (it != fs_.end() && absl::StartsWith(it->first, dir)) { + result->push_back(it->first); + ++it; + } + + return Status::OK(); + } + + Status GetMatchingPaths(const string& pattern, + std::vector* results) override { + mutex_lock m(mu_); + Env* env = Env::Default(); + for (auto it = fs_.begin(); it != fs_.end(); ++it) { + if (env->MatchPath(it->first, pattern)) { + results->push_back(it->first); + } + } + return Status::OK(); + } + + Status Stat(const string& fname, FileStatistics* stat) override { + mutex_lock m(mu_); + auto it = fs_.lower_bound(fname); + if (it == fs_.end()) { + return errors::NotFound(""); + } + + if (it->first == fname) { + stat->is_directory = false; + stat->length = fs_[fname]->size(); + stat->mtime_nsec = 0; + return Status::OK(); + } + + stat->is_directory = true; + stat->length = 0; + stat->mtime_nsec = 0; + return Status::OK(); + } + + Status DeleteFile(const string& fname) override { + mutex_lock m(mu_); + if (fs_.find(fname) != fs_.end()) { + fs_.erase(fname); + return Status::OK(); + } + + return errors::NotFound(""); + } + + Status CreateDir(const string& dirname) override { return Status::OK(); } + + Status RecursivelyCreateDir(const string& dirname) override { + return Status::OK(); + } + + Status DeleteDir(const string& dirname) override { return Status::OK(); } + + Status GetFileSize(const string& fname, uint64* file_size) override { + mutex_lock m(mu_); + if (fs_.find(fname) != fs_.end()) { + *file_size = fs_[fname]->size(); + return Status::OK(); + } + return errors::NotFound(""); + } + + Status RenameFile(const string& src, const string& target) override { + mutex_lock m(mu_); + if (fs_.find(src) != fs_.end()) { + fs_[target] = fs_[src]; + fs_.erase(fs_.find(src)); + return Status::OK(); + } + return errors::NotFound(""); + } + + RamFileSystem() {} + ~RamFileSystem() override {} + + private: + mutex mu_; + std::map> fs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ diff --git a/tensorflow/core/platform/ram_file_system_test.py b/tensorflow/core/platform/ram_file_system_test.py new file mode 100644 index 00000000000..0f4f47ec44e --- /dev/null +++ b/tensorflow/core/platform/ram_file_system_test.py @@ -0,0 +1,119 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for ram_file_system.h.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.estimator.estimator import Estimator +from tensorflow.python.estimator.model_fn import EstimatorSpec +from tensorflow.python.estimator.run_config import RunConfig +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.python.training import training_util + + +class RamFilesystemTest(test_util.TensorFlowTestCase): + + def test_write_file(self): + with gfile.GFile('ram://a.txt', 'w') as f: + f.write('Hello, world.') + f.write('Hello, world.') + + with gfile.GFile('ram://a.txt', 'r') as f: + self.assertEqual(f.read(), 'Hello, world.' * 2) + + def test_append_file_with_seek(self): + with gfile.GFile('ram://c.txt', 'w') as f: + f.write('Hello, world.') + + with gfile.GFile('ram://c.txt', 'w+') as f: + f.seek(offset=0, whence=2) + f.write('Hello, world.') + + with gfile.GFile('ram://c.txt', 'r') as f: + self.assertEqual(f.read(), 'Hello, world.' * 2) + + def test_list_dir(self): + for i in range(10): + with gfile.GFile('ram://a/b/%d.txt' % i, 'w') as f: + f.write('') + with gfile.GFile('ram://c/b/%d.txt' % i, 'w') as f: + f.write('') + + matches = ['ram://a/b/%d.txt' % i for i in range(10)] + self.assertEqual(gfile.ListDirectory('ram://a/b/'), matches) + + def test_glob(self): + for i in range(10): + with gfile.GFile('ram://a/b/%d.txt' % i, 'w') as f: + f.write('') + with gfile.GFile('ram://c/b/%d.txt' % i, 'w') as f: + f.write('') + + matches = ['ram://a/b/%d.txt' % i for i in range(10)] + self.assertEqual(gfile.Glob('ram://a/b/*'), matches) + + matches = [] + self.assertEqual(gfile.Glob('ram://b/b/*'), matches) + + matches = ['ram://c/b/%d.txt' % i for i in range(10)] + self.assertEqual(gfile.Glob('ram://c/b/*'), matches) + + def test_estimator(self): + + def model_fn(features, labels, mode, params): + del params + x = core_layers.dense(features, 100) + x = core_layers.dense(x, 100) + x = core_layers.dense(x, 100) + x = core_layers.dense(x, 100) + y = core_layers.dense(x, 1) + loss = losses.mean_squared_error(labels, y) + opt = adam.AdamOptimizer(learning_rate=0.1) + train_op = opt.minimize( + loss, global_step=training_util.get_or_create_global_step()) + + return EstimatorSpec(mode=mode, loss=loss, train_op=train_op) + + def input_fn(): + batch_size = 128 + return (constant_op.constant(np.random.randn(batch_size, 100), + dtype=dtypes.float32), + constant_op.constant(np.random.randn(batch_size, 1), + dtype=dtypes.float32)) + + config = RunConfig( + model_dir='ram://estimator-0/', save_checkpoints_steps=1) + estimator = Estimator(config=config, model_fn=model_fn) + + estimator.train(input_fn=input_fn, steps=10) + estimator.train(input_fn=input_fn, steps=10) + estimator.train(input_fn=input_fn, steps=10) + estimator.train(input_fn=input_fn, steps=10) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/core/platform/windows/BUILD b/tensorflow/core/platform/windows/BUILD index caba7db0d61..dddb4b9aed4 100644 --- a/tensorflow/core/platform/windows/BUILD +++ b/tensorflow/core/platform/windows/BUILD @@ -24,6 +24,7 @@ cc_library( "//tensorflow/core/platform:env.cc", "//tensorflow/core/platform:file_system.cc", "//tensorflow/core/platform:file_system_helper.cc", + "//tensorflow/core/platform:ram_file_system.h", "//tensorflow/core/platform:threadpool.cc", ], hdrs = [ diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index 843f41765ef..d75d2d5773d 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/load_library.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/ram_file_system.h" #include "tensorflow/core/platform/windows/wide_char.h" #include "tensorflow/core/platform/windows/windows_file_system.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -192,6 +193,7 @@ class WindowsEnv : public Env { REGISTER_FILE_SYSTEM("", WindowsFileSystem); REGISTER_FILE_SYSTEM("file", LocalWinFileSystem); +REGISTER_FILE_SYSTEM("ram", RamFileSystem); Env* Env::Default() { static Env* default_env = new WindowsEnv;