Do not parse OpList in load_op_library
This is not longer needed since op_def_registry and op_def_library both both use OpRegistry::Global() PiperOrigin-RevId: 272161121
This commit is contained in:
parent
a40bd3f49c
commit
dc67973778
@ -24,13 +24,9 @@ import imp
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import threading # pylint: disable=unused-import
|
||||
|
||||
from tensorflow.core.framework import op_def_pb2
|
||||
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
|
||||
from tensorflow.python import pywrap_tensorflow as py_tf
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -59,15 +55,12 @@ def load_op_library(library_filename):
|
||||
RuntimeError: when unable to load the library or get the python wrappers.
|
||||
"""
|
||||
lib_handle = py_tf.TF_LoadLibrary(library_filename)
|
||||
|
||||
op_list_str = py_tf.TF_GetOpList(lib_handle)
|
||||
op_list = op_def_pb2.OpList()
|
||||
op_list.ParseFromString(compat.as_bytes(op_list_str))
|
||||
wrappers = py_tf.GetPythonWrappers(op_list_str)
|
||||
|
||||
# Delete the library handle to release any memory held in C
|
||||
# that are no longer needed.
|
||||
py_tf.TF_DeleteLibraryHandle(lib_handle)
|
||||
try:
|
||||
wrappers = py_tf.GetPythonWrappers(py_tf.TF_GetOpList(lib_handle))
|
||||
finally:
|
||||
# Delete the library handle to release any memory held in C
|
||||
# that are no longer needed.
|
||||
py_tf.TF_DeleteLibraryHandle(lib_handle)
|
||||
|
||||
# Get a unique name for the module.
|
||||
module_name = hashlib.md5(wrappers).hexdigest()
|
||||
@ -76,10 +69,6 @@ def load_op_library(library_filename):
|
||||
module = imp.new_module(module_name)
|
||||
# pylint: disable=exec-used
|
||||
exec(wrappers, module.__dict__)
|
||||
# Stash away the library handle for making calls into the dynamic library.
|
||||
module.LIB_HANDLE = lib_handle
|
||||
# OpDefs of the list of ops defined in the library.
|
||||
module.OP_LIST = op_list
|
||||
# Allow this to be recognized by AutoGraph.
|
||||
setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
|
||||
sys.modules[module_name] = module
|
||||
|
@ -33,9 +33,6 @@ class AckermannTest(test.TestCase):
|
||||
'ackermann_op.so')
|
||||
ackermann = load_library.load_op_library(library_filename)
|
||||
|
||||
self.assertEqual(len(ackermann.OP_LIST.op), 1)
|
||||
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
|
||||
|
||||
with self.cached_session():
|
||||
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
|
||||
|
||||
|
@ -32,9 +32,7 @@ class DuplicateOpTest(test.TestCase):
|
||||
def testBasic(self):
|
||||
library_filename = os.path.join(resource_loader.get_data_files_path(),
|
||||
'duplicate_op.so')
|
||||
duplicate = load_library.load_op_library(library_filename)
|
||||
|
||||
self.assertEqual(len(duplicate.OP_LIST.op), 0)
|
||||
load_library.load_op_library(library_filename)
|
||||
|
||||
with self.cached_session():
|
||||
self.assertEqual(math_ops.add(1, 41).eval(), 42)
|
||||
|
Loading…
Reference in New Issue
Block a user