Fixed a bug in RamFileSystem::FileExists

PiperOrigin-RevId: 329022084
Change-Id: Iaff54394db78fc3818cacf28f12971fcbc9d8529
This commit is contained in:
Zhuo Peng 2020-08-28 15:39:08 -07:00 committed by TensorFlower Gardener
parent ee4736e5bf
commit 49b58c7b7f
2 changed files with 27 additions and 1 deletions

View File

@ -177,7 +177,7 @@ class RamFileSystem : public FileSystem {
FileStatistics* stat) override {
mutex_lock m(mu_);
auto it = fs_.lower_bound(fname);
if (it == fs_.end()) {
if (it == fs_.end() || !absl::StartsWith(it->first, fname)) {
return errors::NotFound("");
}

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import def_function
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.python.estimator.run_config import RunConfig
@ -28,9 +29,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core as core_layers
from tensorflow.python.module import module
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training import adam
from tensorflow.python.training import training_util
@ -82,6 +85,17 @@ class RamFilesystemTest(test_util.TensorFlowTestCase):
matches = ['ram://c/b/%d.txt' % i for i in range(10)]
self.assertEqual(gfile.Glob('ram://c/b/*'), matches)
def test_file_exists(self):
with gfile.GFile('ram://exists/a/b/c.txt', 'w') as f:
f.write('')
self.assertTrue(gfile.Exists('ram://exists/a'))
self.assertTrue(gfile.Exists('ram://exists/a/b'))
self.assertTrue(gfile.Exists('ram://exists/a/b/c.txt'))
self.assertFalse(gfile.Exists('ram://exists/b'))
self.assertFalse(gfile.Exists('ram://exists/a/c'))
self.assertFalse(gfile.Exists('ram://exists/a/b/k'))
def test_estimator(self):
def model_fn(features, labels, mode, params):
@ -114,6 +128,18 @@ class RamFilesystemTest(test_util.TensorFlowTestCase):
estimator.train(input_fn=input_fn, steps=10)
estimator.train(input_fn=input_fn, steps=10)
def test_savedmodel(self):
class MyModule(module.Module):
@def_function.function(input_signature=[])
def foo(self):
return constant_op.constant([1])
saved_model.save(MyModule(), 'ram://my_module')
loaded = saved_model.load('ram://my_module')
self.assertAllEqual(loaded.foo(), [1])
if __name__ == '__main__':
test.main()