Adding support for 'b' in mode for FileIO.
Change: 145590579
This commit is contained in:
parent
191658d54f
commit
39eca117a3
@ -79,7 +79,7 @@ def store_trace_info(output_file_path,
|
||||
file_info = trace_info.files.add()
|
||||
|
||||
with gfile.Open(fpath, 'r') as f:
|
||||
source = f.read().decode('utf-8')
|
||||
source = f.read()
|
||||
|
||||
file_info.file_path = fpath
|
||||
file_info.source_code = source
|
||||
|
@ -85,7 +85,7 @@ class TraceTest(test.TestCase):
|
||||
trace_info = trace.TraceInfo()
|
||||
|
||||
with gfile.Open(self._temp_trace_json) as f:
|
||||
text = f.read().decode('utf-8')
|
||||
text = f.read()
|
||||
json_format.Parse(text, trace_info)
|
||||
|
||||
return trace_info
|
||||
|
@ -69,7 +69,7 @@ class PrintModelAnalysisTest(test.TestCase):
|
||||
self.assertEqual(u'_TFProfRoot (--/450 params)\n'
|
||||
' DW (3x3x3x6, 162/162 params)\n'
|
||||
' DW2 (2x2x6x12, 288/288 params)\n',
|
||||
f.read().decode('utf-8'))
|
||||
f.read())
|
||||
|
||||
def testSelectEverything(self):
|
||||
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
|
||||
@ -96,7 +96,7 @@ class PrintModelAnalysisTest(test.TestCase):
|
||||
# pylint: disable=line-too-long
|
||||
self.assertEqual(
|
||||
'_TFProfRoot (0/450 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
|
||||
f.read().decode('utf-8'))
|
||||
f.read())
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
|
@ -1210,7 +1210,7 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(1, len(ui.unwrapped_outputs))
|
||||
|
||||
with gfile.Open(output_path, "r") as f:
|
||||
self.assertEqual(b"bar\nbar\n", f.read())
|
||||
self.assertEqual("bar\nbar\n", f.read())
|
||||
|
||||
# Clean up output file.
|
||||
gfile.Remove(output_path)
|
||||
|
@ -236,7 +236,7 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
|
||||
screen_output.write_to_file(file_path)
|
||||
|
||||
with gfile.Open(file_path, "r") as f:
|
||||
self.assertEqual(b"Roses are red\nViolets are blue\n", f.read())
|
||||
self.assertEqual("Roses are red\nViolets are blue\n", f.read())
|
||||
|
||||
# Clean up.
|
||||
gfile.Remove(file_path)
|
||||
|
@ -159,7 +159,7 @@ class CursesTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(["bar"] * 2, screen_outputs[0].lines)
|
||||
|
||||
with gfile.Open(output_path, "r") as f:
|
||||
self.assertEqual(b"bar\nbar\n", f.read())
|
||||
self.assertEqual("bar\nbar\n", f.read())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -105,7 +105,7 @@ def _read_file(filename):
|
||||
if not file_io.file_exists(filename):
|
||||
raise IOError("File %s does not exist." % filename)
|
||||
# First try to read it as a binary file.
|
||||
file_content = file_io.read_file_to_string(filename)
|
||||
file_content = file_io.FileIO(filename, "rb").read()
|
||||
try:
|
||||
graph_def.ParseFromString(file_content)
|
||||
return graph_def
|
||||
@ -114,7 +114,7 @@ def _read_file(filename):
|
||||
|
||||
# Next try to read it as a text file.
|
||||
try:
|
||||
text_format.Merge(file_content.decode("utf-8"), graph_def)
|
||||
text_format.Merge(file_content, graph_def)
|
||||
except text_format.ParseError as e:
|
||||
raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
|
||||
|
||||
@ -401,7 +401,7 @@ def read_meta_graph_file(filename):
|
||||
if not file_io.file_exists(filename):
|
||||
raise IOError("File %s does not exist." % filename)
|
||||
# First try to read it as a binary file.
|
||||
file_content = file_io.read_file_to_string(filename)
|
||||
file_content = file_io.FileIO(filename, "rb").read()
|
||||
try:
|
||||
meta_graph_def.ParseFromString(file_content)
|
||||
return meta_graph_def
|
||||
|
@ -34,7 +34,7 @@ class FileIO(object):
|
||||
|
||||
The constructor takes the following arguments:
|
||||
name: name of the file
|
||||
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'.
|
||||
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'. Append 'b' for bytes mode.
|
||||
|
||||
Can be used as an iterator to iterate over lines in the file.
|
||||
|
||||
@ -47,6 +47,8 @@ class FileIO(object):
|
||||
self.__mode = mode
|
||||
self._read_buf = None
|
||||
self._writable_file = None
|
||||
self._binary_mode = "b" in mode
|
||||
mode = mode.replace("b", "")
|
||||
if mode not in ("r", "w", "a", "r+", "w+", "a+"):
|
||||
raise errors.InvalidArgumentError(
|
||||
None, None, "mode is not 'r' or 'w' or 'a' or 'r+' or 'w+' or 'a+'")
|
||||
@ -81,6 +83,12 @@ class FileIO(object):
|
||||
self._writable_file = pywrap_tensorflow.CreateWritableFile(
|
||||
compat.as_bytes(self.__name), compat.as_bytes(self.__mode), status)
|
||||
|
||||
def _prepare_value(self, val):
|
||||
if self._binary_mode:
|
||||
return compat.as_bytes(val)
|
||||
else:
|
||||
return compat.as_str_any(val)
|
||||
|
||||
def size(self):
|
||||
"""Returns the size of the file."""
|
||||
return stat(self.__name).length
|
||||
@ -101,7 +109,8 @@ class FileIO(object):
|
||||
n: Read 'n' bytes if n != -1. If n = -1, reads to end of file.
|
||||
|
||||
Returns:
|
||||
'n' bytes of the file (or whole file) requested as a string.
|
||||
'n' bytes of the file (or whole file) in bytes mode or 'n' bytes of the
|
||||
string if in string (regular) mode.
|
||||
"""
|
||||
self._preread_check()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
@ -109,7 +118,8 @@ class FileIO(object):
|
||||
length = self.size() - self.tell()
|
||||
else:
|
||||
length = n
|
||||
return pywrap_tensorflow.ReadFromStream(self._read_buf, length, status)
|
||||
return self._prepare_value(
|
||||
pywrap_tensorflow.ReadFromStream(self._read_buf, length, status))
|
||||
|
||||
def seek(self, position):
|
||||
"""Seeks to the position in the file."""
|
||||
@ -121,7 +131,7 @@ class FileIO(object):
|
||||
def readline(self):
|
||||
r"""Reads the next line from the file. Leaves the '\n' at the end."""
|
||||
self._preread_check()
|
||||
return compat.as_str_any(self._read_buf.ReadLineAsString())
|
||||
return self._prepare_value(self._read_buf.ReadLineAsString())
|
||||
|
||||
def readlines(self):
|
||||
"""Returns all lines from the file in a list."""
|
||||
|
@ -45,14 +45,26 @@ class FileIoTest(test.TestCase):
|
||||
file_io.write_string_to_file(file_path, "testing")
|
||||
self.assertTrue(file_io.file_exists(file_path))
|
||||
file_contents = file_io.read_file_to_string(file_path)
|
||||
self.assertEqual(b"testing", file_contents)
|
||||
self.assertEqual("testing", file_contents)
|
||||
|
||||
def testAtomicWriteStringToFile(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
file_io.atomic_write_string_to_file(file_path, "testing")
|
||||
self.assertTrue(file_io.file_exists(file_path))
|
||||
file_contents = file_io.read_file_to_string(file_path)
|
||||
self.assertEqual(b"testing", file_contents)
|
||||
self.assertEqual("testing", file_contents)
|
||||
|
||||
def testReadBinaryMode(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
file_io.write_string_to_file(file_path, "testing")
|
||||
with file_io.FileIO(file_path, mode="rb") as f:
|
||||
self.assertEqual(b"testing", f.read())
|
||||
|
||||
def testWriteBinaryMode(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
file_io.FileIO(file_path, "wb").write("testing")
|
||||
with file_io.FileIO(file_path, mode="r") as f:
|
||||
self.assertEqual("testing", f.read())
|
||||
|
||||
def testAppend(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
@ -64,7 +76,7 @@ class FileIoTest(test.TestCase):
|
||||
f.write("a2\n")
|
||||
with file_io.FileIO(file_path, mode="r") as f:
|
||||
file_contents = f.read()
|
||||
self.assertEqual(b"begin\na1\na2\n", file_contents)
|
||||
self.assertEqual("begin\na1\na2\n", file_contents)
|
||||
|
||||
def testMultipleFiles(self):
|
||||
file_prefix = os.path.join(self._base_dir, "temp_file")
|
||||
@ -72,7 +84,7 @@ class FileIoTest(test.TestCase):
|
||||
f = file_io.FileIO(file_prefix + str(i), mode="w+")
|
||||
f.write("testing")
|
||||
f.flush()
|
||||
self.assertEquals(b"testing", f.read())
|
||||
self.assertEqual("testing", f.read())
|
||||
f.close()
|
||||
|
||||
def testMultipleWrites(self):
|
||||
@ -81,7 +93,7 @@ class FileIoTest(test.TestCase):
|
||||
f.write("line1\n")
|
||||
f.write("line2")
|
||||
file_contents = file_io.read_file_to_string(file_path)
|
||||
self.assertEqual(b"line1\nline2", file_contents)
|
||||
self.assertEqual("line1\nline2", file_contents)
|
||||
|
||||
def testFileWriteBadMode(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
@ -143,7 +155,7 @@ class FileIoTest(test.TestCase):
|
||||
file_io.copy(file_path, copy_path)
|
||||
self.assertTrue(file_io.file_exists(copy_path))
|
||||
f = file_io.FileIO(file_path, mode="r")
|
||||
self.assertEqual(b"testing", f.read())
|
||||
self.assertEqual("testing", f.read())
|
||||
self.assertEqual(7, f.tell())
|
||||
|
||||
def testCopyOverwrite(self):
|
||||
@ -153,7 +165,7 @@ class FileIoTest(test.TestCase):
|
||||
file_io.FileIO(copy_path, mode="w").write("copy")
|
||||
file_io.copy(file_path, copy_path, overwrite=True)
|
||||
self.assertTrue(file_io.file_exists(copy_path))
|
||||
self.assertEqual(b"testing", file_io.FileIO(file_path, mode="r").read())
|
||||
self.assertEqual("testing", file_io.FileIO(file_path, mode="r").read())
|
||||
|
||||
def testCopyOverwriteFalse(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
@ -339,10 +351,10 @@ class FileIoTest(test.TestCase):
|
||||
with file_io.FileIO(file_path, mode="r+") as f:
|
||||
f.write("testing1\ntesting2\ntesting3\n\ntesting5")
|
||||
self.assertEqual(36, f.size())
|
||||
self.assertEqual(b"testing1\n", f.read(9))
|
||||
self.assertEqual(b"testing2\n", f.read(9))
|
||||
self.assertEqual(b"t", f.read(1))
|
||||
self.assertEqual(b"esting3\n\ntesting5", f.read())
|
||||
self.assertEqual("testing1\n", f.read(9))
|
||||
self.assertEqual("testing2\n", f.read(9))
|
||||
self.assertEqual("t", f.read(1))
|
||||
self.assertEqual("esting3\n\ntesting5", f.read())
|
||||
|
||||
def testTell(self):
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
@ -409,7 +421,7 @@ class FileIoTest(test.TestCase):
|
||||
|
||||
file_path = os.path.join(self._base_dir, "temp_file")
|
||||
f = file_io.FileIO(file_path, mode="r+")
|
||||
content = b"testing"
|
||||
content = "testing"
|
||||
f.write(content)
|
||||
f.flush()
|
||||
self.assertEqual(content, f.read(len(content) + 1))
|
||||
|
@ -40,7 +40,6 @@ class GFile(_FileIO):
|
||||
"""File I/O wrappers without thread locking."""
|
||||
|
||||
def __init__(self, name, mode='r'):
|
||||
mode = mode.replace('b', '')
|
||||
super(GFile, self).__init__(name=name, mode=mode)
|
||||
|
||||
|
||||
@ -48,7 +47,6 @@ class FastGFile(_FileIO):
|
||||
"""File I/O wrappers without thread locking."""
|
||||
|
||||
def __init__(self, name, mode='r'):
|
||||
mode = mode.replace('b', '')
|
||||
super(FastGFile, self).__init__(name=name, mode=mode)
|
||||
|
||||
|
||||
|
@ -62,7 +62,7 @@ def _parse_saved_model(export_dir):
|
||||
|
||||
# Parse the SavedModel protocol buffer.
|
||||
try:
|
||||
file_content = file_io.read_file_to_string(path_to_pb)
|
||||
file_content = file_io.FileIO(path_to_pb, "rb").read()
|
||||
saved_model.ParseFromString(file_content)
|
||||
return saved_model
|
||||
except Exception: # pylint: disable=broad-except
|
||||
@ -70,7 +70,7 @@ def _parse_saved_model(export_dir):
|
||||
pass
|
||||
|
||||
try:
|
||||
file_content = file_io.read_file_to_string(path_to_pbtxt)
|
||||
file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
|
||||
text_format.Merge(file_content.decode("utf-8"), saved_model)
|
||||
except text_format.ParseError as e:
|
||||
raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
|
||||
|
@ -93,7 +93,7 @@ def freeze_graph(input_graph,
|
||||
if input_binary:
|
||||
input_graph_def.ParseFromString(f.read())
|
||||
else:
|
||||
text_format.Merge(f.read().decode("utf-8"), input_graph_def)
|
||||
text_format.Merge(f.read(), input_graph_def)
|
||||
# Remove all the explicit device specifications for this node. This helps to
|
||||
# make the graph more portable.
|
||||
if clear_devices:
|
||||
|
@ -77,7 +77,7 @@ def main(unused_args):
|
||||
return -1
|
||||
|
||||
input_graph_def = graph_pb2.GraphDef()
|
||||
with gfile.Open(FLAGS.input, "r") as f:
|
||||
with gfile.Open(FLAGS.input, "rb") as f:
|
||||
data = f.read()
|
||||
if FLAGS.frozen_graph:
|
||||
input_graph_def.ParseFromString(data)
|
||||
|
@ -63,7 +63,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
|
||||
for proto_file in proto_files:
|
||||
tf_logging.info('Loading proto file %s', proto_file)
|
||||
# Load GraphDef.
|
||||
file_data = gfile.GFile(proto_file).read()
|
||||
file_data = gfile.GFile(proto_file, 'rb').read()
|
||||
if proto_fileformat == 'rawproto':
|
||||
graph_def = graph_pb2.GraphDef.FromString(file_data)
|
||||
else:
|
||||
|
@ -91,7 +91,7 @@ def strip_unused_from_files(input_graph, input_binary, output_graph,
|
||||
if input_binary:
|
||||
input_graph_def.ParseFromString(f.read())
|
||||
else:
|
||||
text_format.Merge(f.read().decode("utf-8"), input_graph_def)
|
||||
text_format.Merge(f.read(), input_graph_def)
|
||||
|
||||
output_graph_def = strip_unused(input_graph_def,
|
||||
input_node_names.split(","),
|
||||
|
@ -823,7 +823,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
|
||||
# many lines of errors from colossus in the logs.
|
||||
if file_io.file_exists(coord_checkpoint_filename):
|
||||
file_content = file_io.read_file_to_string(
|
||||
coord_checkpoint_filename).decode("utf-8")
|
||||
coord_checkpoint_filename)
|
||||
ckpt = CheckpointState()
|
||||
text_format.Merge(file_content, ckpt)
|
||||
if not ckpt.model_checkpoint_path:
|
||||
|
@ -69,7 +69,7 @@ def _latest_checkpoints_changed(configs, run_path_pairs):
|
||||
config = ProjectorConfig()
|
||||
config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
|
||||
if file_io.file_exists(config_fpath):
|
||||
file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
|
||||
file_content = file_io.read_file_to_string(config_fpath)
|
||||
text_format.Merge(file_content, config)
|
||||
else:
|
||||
config = configs[run_name]
|
||||
@ -205,7 +205,7 @@ class ProjectorPlugin(TBPlugin):
|
||||
config = ProjectorConfig()
|
||||
config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
|
||||
if file_io.file_exists(config_fpath):
|
||||
file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
|
||||
file_content = file_io.read_file_to_string(config_fpath)
|
||||
text_format.Merge(file_content, config)
|
||||
|
||||
has_tensor_files = False
|
||||
@ -415,7 +415,7 @@ class ProjectorPlugin(TBPlugin):
|
||||
return Respond(request, '%s is not a file' % fpath, 'text/plain', 400)
|
||||
|
||||
bookmarks_json = None
|
||||
with file_io.FileIO(fpath, 'r') as f:
|
||||
with file_io.FileIO(fpath, 'rb') as f:
|
||||
bookmarks_json = f.read()
|
||||
return Respond(request, bookmarks_json, 'application/json')
|
||||
|
||||
@ -447,7 +447,7 @@ class ProjectorPlugin(TBPlugin):
|
||||
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
|
||||
return Respond(request, '%s does not exist or is directory' % fpath,
|
||||
'text/plain', 400)
|
||||
f = file_io.FileIO(fpath, 'r')
|
||||
f = file_io.FileIO(fpath, 'rb')
|
||||
encoded_image_string = f.read()
|
||||
f.close()
|
||||
image_type = imghdr.what(None, encoded_image_string)
|
||||
|
@ -79,7 +79,7 @@ def gather_cpu_info():
|
||||
|
||||
# Gather num_cores_allowed
|
||||
try:
|
||||
with gfile.GFile('/proc/self/status') as fh:
|
||||
with gfile.GFile('/proc/self/status', 'rb') as fh:
|
||||
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', fh.read())
|
||||
if nc: # e.g. 'ff' => 8, 'fff' => 12
|
||||
cpu_info.num_cores_allowed = (
|
||||
|
Loading…
Reference in New Issue
Block a user