Fix py3 conversion
This commit is contained in:
parent
d999c5a327
commit
97aa1856bc
@ -1450,7 +1450,6 @@ def main():
|
||||
'more details.')
|
||||
config_info_line('mkl', 'Build with MKL support.')
|
||||
config_info_line('monolithic', 'Config for mostly static monolithic build.')
|
||||
config_info_line('tensorrt', 'Build with TensorRT support.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -22,13 +22,8 @@ from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import errors_impl as _impl
|
||||
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
|
||||
from tensorflow.python.util import compat
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
import six as _six
|
||||
|
||||
# TODO(skama): get outputs from session when implemented as c++
|
||||
# optimization pass
|
||||
@ -48,13 +43,21 @@ def CreateInferenceGraph(input_graph_def,
|
||||
Returns:
|
||||
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
|
||||
"""
|
||||
def py2bytes(inp):
|
||||
return inp
|
||||
def py3bytes(inp):
|
||||
return inp.encode('utf-8',errors='surrogateescape')
|
||||
if _six.PY2:
|
||||
to_bytes=py2bytes
|
||||
else:
|
||||
to_bytes=py3bytes
|
||||
|
||||
out_names = []
|
||||
for i in outputs:
|
||||
if isinstance(i, ops.Tensor):
|
||||
out_names.append(i.name)
|
||||
out_names.append(to_bytes(i.name))
|
||||
else:
|
||||
out_names.append(i)
|
||||
out_names.append(to_bytes(i))
|
||||
|
||||
input_graph_def_str = input_graph_def.SerializeToString()
|
||||
|
||||
@ -63,10 +66,10 @@ def CreateInferenceGraph(input_graph_def,
|
||||
# allow us to return a status object from C++. Thus we return a
|
||||
# pair or strings where first one is encoded status and the second
|
||||
# one is the transformed graphs protobuf string.
|
||||
out = trt_convert(input_graph_def_str, outputs, max_batch_size,
|
||||
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
|
||||
max_workspace_size_bytes)
|
||||
status = out[0]
|
||||
output_graph_def_string = out[1]
|
||||
output_graph_def_string = to_bytes(out[1])
|
||||
del input_graph_def_str #save some memory
|
||||
if len(status) < 2:
|
||||
raise _impl.UnknownError(None, None, status)
|
||||
|
@ -68,3 +68,4 @@ if "__main__" in __name__:
|
||||
o1 = runGraph(gdef, dummy_input)
|
||||
o2 = runGraph(trt_graph, dummy_input)
|
||||
assert (np.array_equal(o1, o2))
|
||||
print("Pass")
|
||||
|
Loading…
Reference in New Issue
Block a user