Do not parse OpList in load_op_library

This is not longer needed since op_def_registry and op_def_library both
both use OpRegistry::Global()

PiperOrigin-RevId: 272161121
This commit is contained in:
Sergei Lebedev 2019-10-01 01:51:08 -07:00 committed by TensorFlower Gardener
parent a40bd3f49c
commit dc67973778
3 changed files with 7 additions and 23 deletions

View File

@ -24,13 +24,9 @@ import imp
import os
import platform
import sys
import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@ -59,15 +55,12 @@ def load_op_library(library_filename):
RuntimeError: when unable to load the library or get the python wrappers.
"""
lib_handle = py_tf.TF_LoadLibrary(library_filename)
op_list_str = py_tf.TF_GetOpList(lib_handle)
op_list = op_def_pb2.OpList()
op_list.ParseFromString(compat.as_bytes(op_list_str))
wrappers = py_tf.GetPythonWrappers(op_list_str)
# Delete the library handle to release any memory held in C
# that are no longer needed.
py_tf.TF_DeleteLibraryHandle(lib_handle)
try:
wrappers = py_tf.GetPythonWrappers(py_tf.TF_GetOpList(lib_handle))
finally:
# Delete the library handle to release any memory held in C
# that are no longer needed.
py_tf.TF_DeleteLibraryHandle(lib_handle)
# Get a unique name for the module.
module_name = hashlib.md5(wrappers).hexdigest()
@ -76,10 +69,6 @@ def load_op_library(library_filename):
module = imp.new_module(module_name)
# pylint: disable=exec-used
exec(wrappers, module.__dict__)
# Stash away the library handle for making calls into the dynamic library.
module.LIB_HANDLE = lib_handle
# OpDefs of the list of ops defined in the library.
module.OP_LIST = op_list
# Allow this to be recognized by AutoGraph.
setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
sys.modules[module_name] = module

View File

@ -33,9 +33,6 @@ class AckermannTest(test.TestCase):
'ackermann_op.so')
ackermann = load_library.load_op_library(library_filename)
self.assertEqual(len(ackermann.OP_LIST.op), 1)
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
with self.cached_session():
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')

View File

@ -32,9 +32,7 @@ class DuplicateOpTest(test.TestCase):
def testBasic(self):
library_filename = os.path.join(resource_loader.get_data_files_path(),
'duplicate_op.so')
duplicate = load_library.load_op_library(library_filename)
self.assertEqual(len(duplicate.OP_LIST.op), 0)
load_library.load_op_library(library_filename)
with self.cached_session():
self.assertEqual(math_ops.add(1, 41).eval(), 42)