Fixed a bug in RamFileSystem::FileExists
PiperOrigin-RevId: 329022084 Change-Id: Iaff54394db78fc3818cacf28f12971fcbc9d8529
This commit is contained in:
parent
ee4736e5bf
commit
49b58c7b7f
tensorflow/core/platform
@ -177,7 +177,7 @@ class RamFileSystem : public FileSystem {
|
||||
FileStatistics* stat) override {
|
||||
mutex_lock m(mu_);
|
||||
auto it = fs_.lower_bound(fname);
|
||||
if (it == fs_.end()) {
|
||||
if (it == fs_.end() || !absl::StartsWith(it->first, fname)) {
|
||||
return errors::NotFound("");
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
||||
from tensorflow.python.estimator.run_config import RunConfig
|
||||
@ -28,9 +29,11 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.layers import core as core_layers
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops.losses import losses
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import saved_model
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
@ -82,6 +85,17 @@ class RamFilesystemTest(test_util.TensorFlowTestCase):
|
||||
matches = ['ram://c/b/%d.txt' % i for i in range(10)]
|
||||
self.assertEqual(gfile.Glob('ram://c/b/*'), matches)
|
||||
|
||||
def test_file_exists(self):
|
||||
with gfile.GFile('ram://exists/a/b/c.txt', 'w') as f:
|
||||
f.write('')
|
||||
self.assertTrue(gfile.Exists('ram://exists/a'))
|
||||
self.assertTrue(gfile.Exists('ram://exists/a/b'))
|
||||
self.assertTrue(gfile.Exists('ram://exists/a/b/c.txt'))
|
||||
|
||||
self.assertFalse(gfile.Exists('ram://exists/b'))
|
||||
self.assertFalse(gfile.Exists('ram://exists/a/c'))
|
||||
self.assertFalse(gfile.Exists('ram://exists/a/b/k'))
|
||||
|
||||
def test_estimator(self):
|
||||
|
||||
def model_fn(features, labels, mode, params):
|
||||
@ -114,6 +128,18 @@ class RamFilesystemTest(test_util.TensorFlowTestCase):
|
||||
estimator.train(input_fn=input_fn, steps=10)
|
||||
estimator.train(input_fn=input_fn, steps=10)
|
||||
|
||||
def test_savedmodel(self):
|
||||
class MyModule(module.Module):
|
||||
|
||||
@def_function.function(input_signature=[])
|
||||
def foo(self):
|
||||
return constant_op.constant([1])
|
||||
|
||||
saved_model.save(MyModule(), 'ram://my_module')
|
||||
|
||||
loaded = saved_model.load('ram://my_module')
|
||||
self.assertAllEqual(loaded.foo(), [1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user