Do not parse OpList in load_op_library
This is not longer needed since op_def_registry and op_def_library both both use OpRegistry::Global() PiperOrigin-RevId: 272161121
This commit is contained in:
parent
a40bd3f49c
commit
dc67973778
@ -24,13 +24,9 @@ import imp
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import threading # pylint: disable=unused-import
|
|
||||||
|
|
||||||
from tensorflow.core.framework import op_def_pb2
|
|
||||||
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
|
|
||||||
from tensorflow.python import pywrap_tensorflow as py_tf
|
from tensorflow.python import pywrap_tensorflow as py_tf
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.util import compat
|
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -59,15 +55,12 @@ def load_op_library(library_filename):
|
|||||||
RuntimeError: when unable to load the library or get the python wrappers.
|
RuntimeError: when unable to load the library or get the python wrappers.
|
||||||
"""
|
"""
|
||||||
lib_handle = py_tf.TF_LoadLibrary(library_filename)
|
lib_handle = py_tf.TF_LoadLibrary(library_filename)
|
||||||
|
try:
|
||||||
op_list_str = py_tf.TF_GetOpList(lib_handle)
|
wrappers = py_tf.GetPythonWrappers(py_tf.TF_GetOpList(lib_handle))
|
||||||
op_list = op_def_pb2.OpList()
|
finally:
|
||||||
op_list.ParseFromString(compat.as_bytes(op_list_str))
|
# Delete the library handle to release any memory held in C
|
||||||
wrappers = py_tf.GetPythonWrappers(op_list_str)
|
# that are no longer needed.
|
||||||
|
py_tf.TF_DeleteLibraryHandle(lib_handle)
|
||||||
# Delete the library handle to release any memory held in C
|
|
||||||
# that are no longer needed.
|
|
||||||
py_tf.TF_DeleteLibraryHandle(lib_handle)
|
|
||||||
|
|
||||||
# Get a unique name for the module.
|
# Get a unique name for the module.
|
||||||
module_name = hashlib.md5(wrappers).hexdigest()
|
module_name = hashlib.md5(wrappers).hexdigest()
|
||||||
@ -76,10 +69,6 @@ def load_op_library(library_filename):
|
|||||||
module = imp.new_module(module_name)
|
module = imp.new_module(module_name)
|
||||||
# pylint: disable=exec-used
|
# pylint: disable=exec-used
|
||||||
exec(wrappers, module.__dict__)
|
exec(wrappers, module.__dict__)
|
||||||
# Stash away the library handle for making calls into the dynamic library.
|
|
||||||
module.LIB_HANDLE = lib_handle
|
|
||||||
# OpDefs of the list of ops defined in the library.
|
|
||||||
module.OP_LIST = op_list
|
|
||||||
# Allow this to be recognized by AutoGraph.
|
# Allow this to be recognized by AutoGraph.
|
||||||
setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
|
setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
|
||||||
sys.modules[module_name] = module
|
sys.modules[module_name] = module
|
||||||
|
@ -33,9 +33,6 @@ class AckermannTest(test.TestCase):
|
|||||||
'ackermann_op.so')
|
'ackermann_op.so')
|
||||||
ackermann = load_library.load_op_library(library_filename)
|
ackermann = load_library.load_op_library(library_filename)
|
||||||
|
|
||||||
self.assertEqual(len(ackermann.OP_LIST.op), 1)
|
|
||||||
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
|
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
|
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
|
||||||
|
|
||||||
|
@ -32,9 +32,7 @@ class DuplicateOpTest(test.TestCase):
|
|||||||
def testBasic(self):
|
def testBasic(self):
|
||||||
library_filename = os.path.join(resource_loader.get_data_files_path(),
|
library_filename = os.path.join(resource_loader.get_data_files_path(),
|
||||||
'duplicate_op.so')
|
'duplicate_op.so')
|
||||||
duplicate = load_library.load_op_library(library_filename)
|
load_library.load_op_library(library_filename)
|
||||||
|
|
||||||
self.assertEqual(len(duplicate.OP_LIST.op), 0)
|
|
||||||
|
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
self.assertEqual(math_ops.add(1, 41).eval(), 42)
|
self.assertEqual(math_ops.add(1, 41).eval(), 42)
|
||||||
|
Loading…
Reference in New Issue
Block a user