Introduce a RamFileSystem for TensorFlow
PiperOrigin-RevId: 308438007 Change-Id: I30afcf7e897659951fdf4b171852923038d2aebe
This commit is contained in:
parent
5f6fe1d508
commit
ecaa695cb1
@ -86,7 +86,6 @@ exports_files(
|
||||
"mutex.h",
|
||||
"net.h",
|
||||
"numa.h",
|
||||
"ram_file_system.h",
|
||||
"resource_loader.h",
|
||||
"resource.h",
|
||||
"snappy.h",
|
||||
@ -383,15 +382,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ram_file_system_test",
|
||||
srcs = ["ram_file_system_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "numbers",
|
||||
srcs = ["numbers.cc"],
|
||||
@ -1303,7 +1293,6 @@ filegroup(
|
||||
"profile_utils/cpu_utils.h",
|
||||
"profile_utils/i_cpu_utils_helper.h",
|
||||
"protobuf.h",
|
||||
"ram_file_system.h",
|
||||
"random.h",
|
||||
"resource.h",
|
||||
"stacktrace.h",
|
||||
@ -1518,7 +1507,6 @@ filegroup(
|
||||
"protobuf.cc",
|
||||
"protobuf.h",
|
||||
"protobuf_util.cc",
|
||||
"ram_file_system.h",
|
||||
"raw_coding.h",
|
||||
"refcount.h",
|
||||
"resource.h",
|
||||
|
@ -87,7 +87,6 @@ cc_library(
|
||||
"//tensorflow/core/platform:env.h",
|
||||
"//tensorflow/core/platform:file_system.h",
|
||||
"//tensorflow/core/platform:file_system_helper.h",
|
||||
"//tensorflow/core/platform:ram_file_system.h",
|
||||
"//tensorflow/core/platform:threadpool.h",
|
||||
],
|
||||
tags = [
|
||||
|
@ -38,7 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/load_library.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/ram_file_system.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -215,8 +214,6 @@ class PosixEnv : public Env {
|
||||
#if defined(PLATFORM_POSIX) || defined(__APPLE__) || defined(__ANDROID__)
|
||||
REGISTER_FILE_SYSTEM("", PosixFileSystem);
|
||||
REGISTER_FILE_SYSTEM("file", LocalPosixFileSystem);
|
||||
REGISTER_FILE_SYSTEM("ram", RamFileSystem);
|
||||
|
||||
Env* Env::Default() {
|
||||
static Env* default_env = new PosixEnv;
|
||||
return default_env;
|
||||
|
@ -1,224 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_
|
||||
|
||||
// Implementation of an in-memory TF filesystem for simple prototyping (e.g.
|
||||
// via Colab). The TPU TF server does not have local filesystem access, which
|
||||
// makes it difficult to provide Colab tutorials: users must have GCS access
|
||||
// and sign-in in order to try out an example.
|
||||
//
|
||||
// Files are implemented on top of std::string. Directories, as with GCS or S3,
|
||||
// are implicit based on the existence of child files. Multiple files may
|
||||
// reference a single FS location, though no thread-safety guarantees are
|
||||
// provided.
|
||||
|
||||
#include <fnmatch.h>

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RamRandomAccessFile : public RandomAccessFile, public WritableFile {
|
||||
public:
|
||||
RamRandomAccessFile(std::string name, std::shared_ptr<std::string> cord)
|
||||
: name_(name), data_(cord) {}
|
||||
~RamRandomAccessFile() override {}
|
||||
|
||||
Status Name(StringPiece* result) const override {
|
||||
*result = name_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Read(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
if (offset >= data_->size()) {
|
||||
return errors::OutOfRange("");
|
||||
}
|
||||
|
||||
uint64 left = std::min(static_cast<uint64>(n), data_->size() - offset);
|
||||
auto start = data_->begin() + offset;
|
||||
auto end = data_->begin() + offset + left;
|
||||
|
||||
std::copy(start, end, scratch);
|
||||
*result = StringPiece(scratch, left);
|
||||
|
||||
// In case of a partial read, we must still fill `result`, but also return
|
||||
// OutOfRange.
|
||||
if (left < n) {
|
||||
return errors::OutOfRange("");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(StringPiece data) override {
|
||||
data_->append(data.data(), data.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Close() override { return Status::OK(); }
|
||||
Status Flush() override { return Status::OK(); }
|
||||
Status Sync() override { return Status::OK(); }
|
||||
|
||||
Status Tell(int64* position) override {
|
||||
*position = -1;
|
||||
return errors::Unimplemented("This filesystem does not support Tell()");
|
||||
}
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RamRandomAccessFile);
|
||||
string name_;
|
||||
std::shared_ptr<std::string> data_;
|
||||
};
|
||||
|
||||
class RamFileSystem : public FileSystem {
|
||||
public:
|
||||
Status NewRandomAccessFile(
|
||||
const string& fname, std::unique_ptr<RandomAccessFile>* result) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(fname) == fs_.end()) {
|
||||
return errors::NotFound("");
|
||||
}
|
||||
*result = std::unique_ptr<RandomAccessFile>(
|
||||
new RamRandomAccessFile(fname, fs_[fname]));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NewWritableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(fname) == fs_.end()) {
|
||||
fs_[fname] = std::make_shared<std::string>();
|
||||
}
|
||||
*result = std::unique_ptr<WritableFile>(
|
||||
new RamRandomAccessFile(fname, fs_[fname]));
|
||||
return Status::OK();
|
||||
}
|
||||
Status NewAppendableFile(const string& fname,
|
||||
std::unique_ptr<WritableFile>* result) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(fname) == fs_.end()) {
|
||||
fs_[fname] = std::make_shared<std::string>();
|
||||
}
|
||||
*result = std::unique_ptr<WritableFile>(
|
||||
new RamRandomAccessFile(fname, fs_[fname]));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NewReadOnlyMemoryRegionFromFile(
|
||||
const string& fname,
|
||||
std::unique_ptr<ReadOnlyMemoryRegion>* result) override {
|
||||
return errors::Unimplemented("");
|
||||
}
|
||||
|
||||
Status FileExists(const string& fname) override {
|
||||
FileStatistics stat;
|
||||
return Stat(fname, &stat);
|
||||
}
|
||||
|
||||
Status GetChildren(const string& dir, std::vector<string>* result) override {
|
||||
mutex_lock m(mu_);
|
||||
auto it = fs_.lower_bound(dir);
|
||||
while (it != fs_.end() && absl::StartsWith(it->first, dir)) {
|
||||
result->push_back(it->first);
|
||||
++it;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetMatchingPaths(const string& pattern,
|
||||
std::vector<string>* results) override {
|
||||
mutex_lock m(mu_);
|
||||
for (auto it = fs_.begin(); it != fs_.end(); ++it) {
|
||||
if (fnmatch(pattern.c_str(), it->first.c_str(), FNM_PATHNAME) == 0) {
|
||||
results->push_back(it->first);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Stat(const string& fname, FileStatistics* stat) override {
|
||||
mutex_lock m(mu_);
|
||||
auto it = fs_.lower_bound(fname);
|
||||
if (it == fs_.end()) {
|
||||
return errors::NotFound("");
|
||||
}
|
||||
|
||||
if (it->first == fname) {
|
||||
stat->is_directory = false;
|
||||
stat->length = fs_[fname]->size();
|
||||
stat->mtime_nsec = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
stat->is_directory = true;
|
||||
stat->length = 0;
|
||||
stat->mtime_nsec = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeleteFile(const string& fname) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(fname) != fs_.end()) {
|
||||
fs_.erase(fname);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return errors::NotFound("");
|
||||
}
|
||||
|
||||
Status CreateDir(const string& dirname) override { return Status::OK(); }
|
||||
Status RecursivelyCreateDir(const string& dirname) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeleteDir(const string& dirname) override { return Status::OK(); }
|
||||
|
||||
Status GetFileSize(const string& fname, uint64* file_size) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(fname) != fs_.end()) {
|
||||
*file_size = fs_[fname]->size();
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::NotFound("");
|
||||
}
|
||||
|
||||
Status RenameFile(const string& src, const string& target) override {
|
||||
mutex_lock m(mu_);
|
||||
if (fs_.find(src) != fs_.end()) {
|
||||
fs_[target] = fs_[src];
|
||||
fs_.erase(fs_.find(src));
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::NotFound("");
|
||||
}
|
||||
|
||||
RamFileSystem() {}
|
||||
~RamFileSystem() override {}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
std::map<string, std::shared_ptr<std::string>> fs_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_
|
@ -1,119 +0,0 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for ram_file_system.h."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
||||
from tensorflow.python.estimator.run_config import RunConfig
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as core_layers
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
class RamFilesystemTest(test_util.TensorFlowTestCase):
  """Exercises the `ram://` filesystem through gfile and an Estimator."""

  def _populate_two_directories(self):
    """Writes ten empty files under each of ram://a/b/ and ram://c/b/."""
    for index in range(10):
      for template in ('ram://a/b/%d.txt', 'ram://c/b/%d.txt'):
        with gfile.GFile(template % index, 'w') as handle:
          handle.write('')

  def test_write_file(self):
    """Two writes through one handle are read back concatenated."""
    greeting = 'Hello, world.'
    with gfile.GFile('ram://a.txt', 'w') as writer:
      for _ in range(2):
        writer.write(greeting)

    with gfile.GFile('ram://a.txt', 'r') as reader:
      self.assertEqual(reader.read(), greeting * 2)

  def test_append_file_with_seek(self):
    """Reopening with 'w+' and seeking to the end keeps the earlier bytes."""
    path = 'ram://c.txt'
    greeting = 'Hello, world.'

    with gfile.GFile(path, 'w') as writer:
      writer.write(greeting)

    with gfile.GFile(path, 'w+') as writer:
      # whence=2 positions relative to the end of the file.
      writer.seek(offset=0, whence=2)
      writer.write(greeting)

    with gfile.GFile(path, 'r') as reader:
      self.assertEqual(reader.read(), greeting * 2)

  def test_list_dir(self):
    """ListDirectory returns the full paths stored under the prefix."""
    self._populate_two_directories()

    expected = ['ram://a/b/%d.txt' % index for index in range(10)]
    self.assertEqual(gfile.ListDirectory('ram://a/b/'), expected)

  def test_glob(self):
    """Glob matches only the files under the requested prefix."""
    self._populate_two_directories()

    expected_a = ['ram://a/b/%d.txt' % index for index in range(10)]
    self.assertEqual(gfile.Glob('ram://a/b/*'), expected_a)

    # Nothing was ever written under ram://b/.
    self.assertEqual(gfile.Glob('ram://b/b/*'), [])

    expected_c = ['ram://c/b/%d.txt' % index for index in range(10)]
    self.assertEqual(gfile.Glob('ram://c/b/*'), expected_c)

  def test_estimator(self):
    """Trains an Estimator repeatedly with a ram:// model directory."""

    def model_fn(features, labels, mode, params):
      del params
      # Four 100-unit dense layers followed by a single-unit output layer.
      hidden = features
      for _ in range(4):
        hidden = core_layers.dense(hidden, 100)
      prediction = core_layers.dense(hidden, 1)

      loss = losses.mean_squared_error(labels, prediction)
      optimizer = adam.AdamOptimizer(learning_rate=0.1)
      train_op = optimizer.minimize(
          loss, global_step=training_util.get_or_create_global_step())

      return EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    def input_fn():
      batch_size = 128
      inputs = constant_op.constant(
          np.random.randn(batch_size, 100), dtype=dtypes.float32)
      targets = constant_op.constant(
          np.random.randn(batch_size, 1), dtype=dtypes.float32)
      return (inputs, targets)

    config = RunConfig(
        model_dir='ram://estimator-0/', save_checkpoints_steps=1)
    estimator = Estimator(config=config, model_fn=model_fn)

    # Four consecutive training runs against the same model directory.
    for _ in range(4):
      estimator.train(input_fn=input_fn, steps=10)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -31,7 +31,6 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/platform/load_library.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/ram_file_system.h"
|
||||
#include "tensorflow/core/platform/windows/wide_char.h"
|
||||
#include "tensorflow/core/platform/windows/windows_file_system.h"
|
||||
#include "tensorflow/core/protobuf/error_codes.pb.h"
|
||||
@ -193,7 +192,6 @@ class WindowsEnv : public Env {
|
||||
|
||||
REGISTER_FILE_SYSTEM("", WindowsFileSystem);
|
||||
REGISTER_FILE_SYSTEM("file", LocalWinFileSystem);
|
||||
REGISTER_FILE_SYSTEM("ram", RamFileSystem);
|
||||
|
||||
Env* Env::Default() {
|
||||
static Env* default_env = new WindowsEnv;
|
||||
|
Loading…
Reference in New Issue
Block a user