Add support for 'b' (binary mode) in the mode string for FileIO.

Change: 145590579
This commit is contained in:
Rohan Jain 2017-01-25 13:42:13 -08:00 committed by TensorFlower Gardener
parent 191658d54f
commit 39eca117a3
18 changed files with 60 additions and 40 deletions

View File

@ -79,7 +79,7 @@ def store_trace_info(output_file_path,
file_info = trace_info.files.add()
with gfile.Open(fpath, 'r') as f:
source = f.read().decode('utf-8')
source = f.read()
file_info.file_path = fpath
file_info.source_code = source

View File

@ -85,7 +85,7 @@ class TraceTest(test.TestCase):
trace_info = trace.TraceInfo()
with gfile.Open(self._temp_trace_json) as f:
text = f.read().decode('utf-8')
text = f.read()
json_format.Parse(text, trace_info)
return trace_info

View File

@ -69,7 +69,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertEqual(u'_TFProfRoot (--/450 params)\n'
' DW (3x3x3x6, 162/162 params)\n'
' DW2 (2x2x6x12, 288/288 params)\n',
f.read().decode('utf-8'))
f.read())
def testSelectEverything(self):
opts = model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS
@ -96,7 +96,7 @@ class PrintModelAnalysisTest(test.TestCase):
# pylint: disable=line-too-long
self.assertEqual(
'_TFProfRoot (0/450 params, 0/10.44k flops, 0B/5.28KB, _kTFScopeParent)\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables)\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|Assign)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent)\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const)\n 
DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const)\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity)\n init (0/0 params, 0/0 flops, 0B/0B, /device:CPU:0, /device:CPU:0|NoOp)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const)\n',
f.read().decode('utf-8'))
f.read())
# pylint: enable=line-too-long

View File

@ -1210,7 +1210,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(ui.unwrapped_outputs))
with gfile.Open(output_path, "r") as f:
self.assertEqual(b"bar\nbar\n", f.read())
self.assertEqual("bar\nbar\n", f.read())
# Clean up output file.
gfile.Remove(output_path)

View File

@ -236,7 +236,7 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
screen_output.write_to_file(file_path)
with gfile.Open(file_path, "r") as f:
self.assertEqual(b"Roses are red\nViolets are blue\n", f.read())
self.assertEqual("Roses are red\nViolets are blue\n", f.read())
# Clean up.
gfile.Remove(file_path)

View File

@ -159,7 +159,7 @@ class CursesTest(test_util.TensorFlowTestCase):
self.assertEqual(["bar"] * 2, screen_outputs[0].lines)
with gfile.Open(output_path, "r") as f:
self.assertEqual(b"bar\nbar\n", f.read())
self.assertEqual("bar\nbar\n", f.read())
if __name__ == "__main__":

View File

@ -105,7 +105,7 @@ def _read_file(filename):
if not file_io.file_exists(filename):
raise IOError("File %s does not exist." % filename)
# First try to read it as a binary file.
file_content = file_io.read_file_to_string(filename)
file_content = file_io.FileIO(filename, "rb").read()
try:
graph_def.ParseFromString(file_content)
return graph_def
@ -114,7 +114,7 @@ def _read_file(filename):
# Next try to read it as a text file.
try:
text_format.Merge(file_content.decode("utf-8"), graph_def)
text_format.Merge(file_content, graph_def)
except text_format.ParseError as e:
raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
@ -401,7 +401,7 @@ def read_meta_graph_file(filename):
if not file_io.file_exists(filename):
raise IOError("File %s does not exist." % filename)
# First try to read it as a binary file.
file_content = file_io.read_file_to_string(filename)
file_content = file_io.FileIO(filename, "rb").read()
try:
meta_graph_def.ParseFromString(file_content)
return meta_graph_def

View File

@ -34,7 +34,7 @@ class FileIO(object):
The constructor takes the following arguments:
name: name of the file
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'.
mode: one of 'r', 'w', 'a', 'r+', 'w+', 'a+'. Append 'b' for bytes mode.
Can be used as an iterator to iterate over lines in the file.
@ -47,6 +47,8 @@ class FileIO(object):
self.__mode = mode
self._read_buf = None
self._writable_file = None
self._binary_mode = "b" in mode
mode = mode.replace("b", "")
if mode not in ("r", "w", "a", "r+", "w+", "a+"):
raise errors.InvalidArgumentError(
None, None, "mode is not 'r' or 'w' or 'a' or 'r+' or 'w+' or 'a+'")
@ -81,6 +83,12 @@ class FileIO(object):
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode), status)
def _prepare_value(self, val):
if self._binary_mode:
return compat.as_bytes(val)
else:
return compat.as_str_any(val)
def size(self):
"""Returns the size of the file."""
return stat(self.__name).length
@ -101,7 +109,8 @@ class FileIO(object):
n: Read 'n' bytes if n != -1. If n = -1, reads to end of file.
Returns:
'n' bytes of the file (or whole file) requested as a string.
'n' bytes of the file (or whole file) in bytes mode or 'n' bytes of the
string if in string (regular) mode.
"""
self._preread_check()
with errors.raise_exception_on_not_ok_status() as status:
@ -109,7 +118,8 @@ class FileIO(object):
length = self.size() - self.tell()
else:
length = n
return pywrap_tensorflow.ReadFromStream(self._read_buf, length, status)
return self._prepare_value(
pywrap_tensorflow.ReadFromStream(self._read_buf, length, status))
def seek(self, position):
"""Seeks to the position in the file."""
@ -121,7 +131,7 @@ class FileIO(object):
def readline(self):
r"""Reads the next line from the file. Leaves the '\n' at the end."""
self._preread_check()
return compat.as_str_any(self._read_buf.ReadLineAsString())
return self._prepare_value(self._read_buf.ReadLineAsString())
def readlines(self):
"""Returns all lines from the file in a list."""

View File

@ -45,14 +45,26 @@ class FileIoTest(test.TestCase):
file_io.write_string_to_file(file_path, "testing")
self.assertTrue(file_io.file_exists(file_path))
file_contents = file_io.read_file_to_string(file_path)
self.assertEqual(b"testing", file_contents)
self.assertEqual("testing", file_contents)
def testAtomicWriteStringToFile(self):
file_path = os.path.join(self._base_dir, "temp_file")
file_io.atomic_write_string_to_file(file_path, "testing")
self.assertTrue(file_io.file_exists(file_path))
file_contents = file_io.read_file_to_string(file_path)
self.assertEqual(b"testing", file_contents)
self.assertEqual("testing", file_contents)
def testReadBinaryMode(self):
file_path = os.path.join(self._base_dir, "temp_file")
file_io.write_string_to_file(file_path, "testing")
with file_io.FileIO(file_path, mode="rb") as f:
self.assertEqual(b"testing", f.read())
def testWriteBinaryMode(self):
file_path = os.path.join(self._base_dir, "temp_file")
file_io.FileIO(file_path, "wb").write("testing")
with file_io.FileIO(file_path, mode="r") as f:
self.assertEqual("testing", f.read())
def testAppend(self):
file_path = os.path.join(self._base_dir, "temp_file")
@ -64,7 +76,7 @@ class FileIoTest(test.TestCase):
f.write("a2\n")
with file_io.FileIO(file_path, mode="r") as f:
file_contents = f.read()
self.assertEqual(b"begin\na1\na2\n", file_contents)
self.assertEqual("begin\na1\na2\n", file_contents)
def testMultipleFiles(self):
file_prefix = os.path.join(self._base_dir, "temp_file")
@ -72,7 +84,7 @@ class FileIoTest(test.TestCase):
f = file_io.FileIO(file_prefix + str(i), mode="w+")
f.write("testing")
f.flush()
self.assertEquals(b"testing", f.read())
self.assertEqual("testing", f.read())
f.close()
def testMultipleWrites(self):
@ -81,7 +93,7 @@ class FileIoTest(test.TestCase):
f.write("line1\n")
f.write("line2")
file_contents = file_io.read_file_to_string(file_path)
self.assertEqual(b"line1\nline2", file_contents)
self.assertEqual("line1\nline2", file_contents)
def testFileWriteBadMode(self):
file_path = os.path.join(self._base_dir, "temp_file")
@ -143,7 +155,7 @@ class FileIoTest(test.TestCase):
file_io.copy(file_path, copy_path)
self.assertTrue(file_io.file_exists(copy_path))
f = file_io.FileIO(file_path, mode="r")
self.assertEqual(b"testing", f.read())
self.assertEqual("testing", f.read())
self.assertEqual(7, f.tell())
def testCopyOverwrite(self):
@ -153,7 +165,7 @@ class FileIoTest(test.TestCase):
file_io.FileIO(copy_path, mode="w").write("copy")
file_io.copy(file_path, copy_path, overwrite=True)
self.assertTrue(file_io.file_exists(copy_path))
self.assertEqual(b"testing", file_io.FileIO(file_path, mode="r").read())
self.assertEqual("testing", file_io.FileIO(file_path, mode="r").read())
def testCopyOverwriteFalse(self):
file_path = os.path.join(self._base_dir, "temp_file")
@ -339,10 +351,10 @@ class FileIoTest(test.TestCase):
with file_io.FileIO(file_path, mode="r+") as f:
f.write("testing1\ntesting2\ntesting3\n\ntesting5")
self.assertEqual(36, f.size())
self.assertEqual(b"testing1\n", f.read(9))
self.assertEqual(b"testing2\n", f.read(9))
self.assertEqual(b"t", f.read(1))
self.assertEqual(b"esting3\n\ntesting5", f.read())
self.assertEqual("testing1\n", f.read(9))
self.assertEqual("testing2\n", f.read(9))
self.assertEqual("t", f.read(1))
self.assertEqual("esting3\n\ntesting5", f.read())
def testTell(self):
file_path = os.path.join(self._base_dir, "temp_file")
@ -409,7 +421,7 @@ class FileIoTest(test.TestCase):
file_path = os.path.join(self._base_dir, "temp_file")
f = file_io.FileIO(file_path, mode="r+")
content = b"testing"
content = "testing"
f.write(content)
f.flush()
self.assertEqual(content, f.read(len(content) + 1))

View File

@ -40,7 +40,6 @@ class GFile(_FileIO):
"""File I/O wrappers without thread locking."""
def __init__(self, name, mode='r'):
mode = mode.replace('b', '')
super(GFile, self).__init__(name=name, mode=mode)
@ -48,7 +47,6 @@ class FastGFile(_FileIO):
"""File I/O wrappers without thread locking."""
def __init__(self, name, mode='r'):
mode = mode.replace('b', '')
super(FastGFile, self).__init__(name=name, mode=mode)

View File

@ -62,7 +62,7 @@ def _parse_saved_model(export_dir):
# Parse the SavedModel protocol buffer.
try:
file_content = file_io.read_file_to_string(path_to_pb)
file_content = file_io.FileIO(path_to_pb, "rb").read()
saved_model.ParseFromString(file_content)
return saved_model
except Exception: # pylint: disable=broad-except
@ -70,7 +70,7 @@ def _parse_saved_model(export_dir):
pass
try:
file_content = file_io.read_file_to_string(path_to_pbtxt)
file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
text_format.Merge(file_content.decode("utf-8"), saved_model)
except text_format.ParseError as e:
raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))

View File

@ -93,7 +93,7 @@ def freeze_graph(input_graph,
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read().decode("utf-8"), input_graph_def)
text_format.Merge(f.read(), input_graph_def)
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:

View File

@ -77,7 +77,7 @@ def main(unused_args):
return -1
input_graph_def = graph_pb2.GraphDef()
with gfile.Open(FLAGS.input, "r") as f:
with gfile.Open(FLAGS.input, "rb") as f:
data = f.read()
if FLAGS.frozen_graph:
input_graph_def.ParseFromString(data)

View File

@ -63,7 +63,7 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
for proto_file in proto_files:
tf_logging.info('Loading proto file %s', proto_file)
# Load GraphDef.
file_data = gfile.GFile(proto_file).read()
file_data = gfile.GFile(proto_file, 'rb').read()
if proto_fileformat == 'rawproto':
graph_def = graph_pb2.GraphDef.FromString(file_data)
else:

View File

@ -91,7 +91,7 @@ def strip_unused_from_files(input_graph, input_binary, output_graph,
if input_binary:
input_graph_def.ParseFromString(f.read())
else:
text_format.Merge(f.read().decode("utf-8"), input_graph_def)
text_format.Merge(f.read(), input_graph_def)
output_graph_def = strip_unused(input_graph_def,
input_node_names.split(","),

View File

@ -823,7 +823,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
# many lines of errors from colossus in the logs.
if file_io.file_exists(coord_checkpoint_filename):
file_content = file_io.read_file_to_string(
coord_checkpoint_filename).decode("utf-8")
coord_checkpoint_filename)
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:

View File

@ -69,7 +69,7 @@ def _latest_checkpoints_changed(configs, run_path_pairs):
config = ProjectorConfig()
config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
if file_io.file_exists(config_fpath):
file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
file_content = file_io.read_file_to_string(config_fpath)
text_format.Merge(file_content, config)
else:
config = configs[run_name]
@ -205,7 +205,7 @@ class ProjectorPlugin(TBPlugin):
config = ProjectorConfig()
config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
if file_io.file_exists(config_fpath):
file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
file_content = file_io.read_file_to_string(config_fpath)
text_format.Merge(file_content, config)
has_tensor_files = False
@ -415,7 +415,7 @@ class ProjectorPlugin(TBPlugin):
return Respond(request, '%s is not a file' % fpath, 'text/plain', 400)
bookmarks_json = None
with file_io.FileIO(fpath, 'r') as f:
with file_io.FileIO(fpath, 'rb') as f:
bookmarks_json = f.read()
return Respond(request, bookmarks_json, 'application/json')
@ -447,7 +447,7 @@ class ProjectorPlugin(TBPlugin):
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
return Respond(request, '%s does not exist or is directory' % fpath,
'text/plain', 400)
f = file_io.FileIO(fpath, 'r')
f = file_io.FileIO(fpath, 'rb')
encoded_image_string = f.read()
f.close()
image_type = imghdr.what(None, encoded_image_string)

View File

@ -79,7 +79,7 @@ def gather_cpu_info():
# Gather num_cores_allowed
try:
with gfile.GFile('/proc/self/status') as fh:
with gfile.GFile('/proc/self/status', 'rb') as fh:
nc = re.search(r'(?m)^Cpus_allowed:\s*(.*)$', fh.read())
if nc: # e.g. 'ff' => 8, 'fff' => 12
cpu_info.num_cores_allowed = (