Fix py3 conversion

This commit is contained in:
Sami Kama 2018-02-02 10:13:41 -08:00
parent d999c5a327
commit 97aa1856bc
3 changed files with 14 additions and 11 deletions

View File

@ -1450,7 +1450,6 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
config_info_line('tensorrt', 'Build with TensorRT support.')
if __name__ == '__main__':
main()

View File

@ -22,13 +22,8 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.python.util import compat
import tensorflow as tf
from tensorflow.python.grappler import tf_optimizer
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
import six as _six
# TODO(skama): get outputs from session when implemented as c++
# optimization pass
@ -48,13 +43,21 @@ def CreateInferenceGraph(input_graph_def,
Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
"""
def py2bytes(inp):
return inp
def py3bytes(inp):
return inp.encode('utf-8',errors='surrogateescape')
if _six.PY2:
to_bytes=py2bytes
else:
to_bytes=py3bytes
out_names = []
for i in outputs:
if isinstance(i, ops.Tensor):
out_names.append(i.name)
out_names.append(to_bytes(i.name))
else:
out_names.append(i)
out_names.append(to_bytes(i))
input_graph_def_str = input_graph_def.SerializeToString()
@ -63,10 +66,10 @@ def CreateInferenceGraph(input_graph_def,
# allow us to return a status object from C++. Thus we return a
# pair or strings where first one is encoded status and the second
# one is the transformed graphs protobuf string.
out = trt_convert(input_graph_def_str, outputs, max_batch_size,
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
max_workspace_size_bytes)
status = out[0]
output_graph_def_string = out[1]
output_graph_def_string = to_bytes(out[1])
del input_graph_def_str #save some memory
if len(status) < 2:
raise _impl.UnknownError(None, None, status)

View File

@ -68,3 +68,4 @@ if "__main__" in __name__:
o1 = runGraph(gdef, dummy_input)
o2 = runGraph(trt_graph, dummy_input)
assert (np.array_equal(o1, o2))
print("Pass")