Link TF-TRT into python:no_contrib (so it's included in TF 2.0 build) and lazily load TRT shared library.
PiperOrigin-RevId: 233494604
This commit is contained in:
parent
5a297f5efe
commit
8bb8258646
@ -160,6 +160,7 @@ tf_custom_op_py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resources",
|
||||
|
@ -18,17 +18,45 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import platform
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
|
||||
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import *
|
||||
_trt_ops_so = None
|
||||
_module_lock = threading.Lock()
|
||||
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.platform import resource_loader
|
||||
# pylint: enable=wildcard-import,unused-import,g-import-not-at-top
|
||||
|
||||
_trt_ops = load_library.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_trt_ops.so"))
|
||||
else:
|
||||
raise RuntimeError("Windows platforms are not supported")
|
||||
def load_trt_ops():
|
||||
"""Load TF-TRT op libraries so if it hasn't been loaded already."""
|
||||
global _trt_ops_so
|
||||
|
||||
if platform.system() == "Windows":
|
||||
raise RuntimeError("Windows platforms are not supported")
|
||||
|
||||
with _module_lock:
|
||||
if _trt_ops_so:
|
||||
return
|
||||
|
||||
# TODO(laigd): we should load TF-TRT kernels here as well after removing the
|
||||
# swig binding.
|
||||
try:
|
||||
# TODO(lagid): consider getting rid of these unused imports.
|
||||
# pylint: disable=unused-import,g-import-not-at-top,unused-variable
|
||||
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
|
||||
from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import trt_engine_op
|
||||
from tensorflow.python.framework import load_library
|
||||
from tensorflow.python.platform import resource_loader
|
||||
# pylint: enable=unused-import,g-import-not-at-top,unused-variable
|
||||
|
||||
_trt_ops_so = load_library.load_op_library(
|
||||
resource_loader.get_path_to_datafile("_trt_ops.so"))
|
||||
except errors.NotFoundError as e:
|
||||
no_trt_message = (
|
||||
"**** Failed to initialize TensorRT. This is either because the "
|
||||
"TensorRT installation path is not in LD_LIBRARY_PATH, or because "
|
||||
"you do not have it installed. If not installed, please go to "
|
||||
"https://developer.nvidia.com/tensorrt to download and install "
|
||||
"TensorRT ****")
|
||||
print(no_trt_message)
|
||||
raise e
|
||||
|
@ -173,6 +173,7 @@ py_library(
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python/compat",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/compiler",
|
||||
"//tensorflow/python/data",
|
||||
"//tensorflow/python/distribute",
|
||||
"//tensorflow/python/distribute:estimator_training",
|
||||
|
17
tensorflow/python/compiler/BUILD
Normal file
17
tensorflow/python/compiler/BUILD
Normal file
@ -0,0 +1,17 @@
|
||||
# Description:
|
||||
# Python APIs for various Tensorflow backends.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "compiler",
|
||||
srcs = ["__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python/compiler/tensorrt:init_py",
|
||||
],
|
||||
)
|
0
tensorflow/python/compiler/__init__.py
Normal file
0
tensorflow/python/compiler/__init__.py
Normal file
@ -32,8 +32,6 @@ py_library(
|
||||
deps = [
|
||||
":tf_trt_integration_test_base",
|
||||
":trt_convert_py",
|
||||
":trt_ops_py",
|
||||
"//tensorflow/python:errors",
|
||||
],
|
||||
)
|
||||
|
||||
@ -51,6 +49,7 @@ py_library(
|
||||
srcs = ["trt_convert.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":trt_ops_py",
|
||||
":wrap_conversion",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//tensorflow/python:session",
|
||||
@ -83,7 +82,6 @@ py_library(
|
||||
srcs = ["test/tf_trt_integration_test_base.py"],
|
||||
deps = [
|
||||
":trt_convert_py",
|
||||
":trt_ops_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
|
@ -18,25 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
# pylint: disable=unused-import,g-import-not-at-top,line-too-long
|
||||
try:
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import add_test_value
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import calib_graph_to_infer_graph
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import clear_test_values
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import enable_test_value
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import get_test_value
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import is_tensorrt_enabled
|
||||
except errors.NotFoundError as e:
|
||||
no_trt_message = (
|
||||
'**** Failed to initialize TensorRT. This is either because the TensorRT'
|
||||
' installation path is not in LD_LIBRARY_PATH, or because you do not have'
|
||||
' it installed. If not installed, please go to'
|
||||
' https://developer.nvidia.com/tensorrt to download and install'
|
||||
' TensorRT ****')
|
||||
print(no_trt_message)
|
||||
raise e
|
||||
# pylint: enable=unused-import,g-import-not-at-top,line-too-long
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
|
@ -20,8 +20,9 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -154,7 +155,7 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
|
||||
"""Setup method."""
|
||||
super(PartiallyConvertedTestA, self).setUp()
|
||||
# Let it fail to build the second engine.
|
||||
trt_convert.add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
|
||||
add_test_value("TRTEngineOp_1:CreateTRTNode", "fail")
|
||||
|
||||
def GetParams(self):
|
||||
"""Create a graph containing two segment."""
|
||||
@ -209,8 +210,8 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
|
||||
"""Setup method."""
|
||||
super(PartiallyConvertedTestB, self).setUp()
|
||||
# Let it fail to build the first engine.
|
||||
trt_convert.clear_test_values("")
|
||||
trt_convert.add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
|
||||
clear_test_values("")
|
||||
add_test_value("TRTEngineOp_0:CreateTRTNode", "fail")
|
||||
|
||||
def ExpectedEnginesToBuild(self, run_params):
|
||||
"""Return the expected engines to build."""
|
||||
|
@ -18,13 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
# pylint: enable=unused-import
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import data
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
from tensorflow.python.estimator.model_fn import EstimatorSpec
|
||||
from tensorflow.python.estimator.model_fn import ModeKeys
|
||||
@ -263,7 +262,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
||||
# num_epochs=100,
|
||||
# model_dir=model_dir)
|
||||
def testEval(self):
|
||||
if not trt_convert.is_tensorrt_enabled():
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata')
|
||||
|
||||
@ -276,7 +275,7 @@ class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
|
||||
logging.info('accuracy_tf_native: %f', accuracy_tf_native)
|
||||
self.assertAllClose(0.9662, accuracy_tf_native, rtol=1e-3, atol=1e-3)
|
||||
|
||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
||||
if get_linked_tensorrt_version()[0] < 5:
|
||||
return
|
||||
|
||||
accuracy_tf_trt = self._Run(
|
||||
|
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -72,7 +72,7 @@ class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase):
|
||||
return _GetParams(add_quantization_nodes=False)
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
||||
if get_linked_tensorrt_version()[0] < 5:
|
||||
return False
|
||||
# Only test static engine mode, with or without calibration.
|
||||
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
||||
@ -96,7 +96,7 @@ class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase):
|
||||
return _GetParams(add_quantization_nodes=True)
|
||||
|
||||
def ShouldRunTest(self, run_params):
|
||||
if trt_convert.get_linked_tensorrt_version()[0] < 5:
|
||||
if get_linked_tensorrt_version()[0] < 5:
|
||||
return False
|
||||
# Test static/dynamic engine with/without calibration.
|
||||
return (trt_test.IsQuantizationMode(run_params.precision_mode) and
|
||||
|
@ -25,12 +25,13 @@ import warnings
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
# pylint: enable=unused-import
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.compiler.tensorrt import trt_convert
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import graph_io
|
||||
from tensorflow.python.framework import importer
|
||||
@ -151,7 +152,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
def setUpClass(cls):
|
||||
"""Setup method for the module."""
|
||||
super(TfTrtIntegrationTestBase, cls).setUpClass()
|
||||
trt_convert.enable_test_value()
|
||||
enable_test_value()
|
||||
|
||||
def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
|
||||
super(TfTrtIntegrationTestBase, self).__init__(methodName)
|
||||
@ -161,7 +162,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
"""Setup method."""
|
||||
super(TfTrtIntegrationTestBase, self).setUp()
|
||||
warnings.simplefilter("always")
|
||||
trt_convert.clear_test_values("")
|
||||
clear_test_values("")
|
||||
|
||||
def GetParams(self):
|
||||
"""Return a TfTrtIntegrationTestParams for test, implemented by subclass."""
|
||||
@ -246,9 +247,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
def _PrepareRun(self, graph_state):
|
||||
"""Set up necessary testing environment before calling sess.run()."""
|
||||
# Clear test values added by TRTEngineOp.
|
||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
|
||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
|
||||
trt_convert.clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
|
||||
clear_test_values("TRTEngineOp_.*:ExecuteTrtEngine")
|
||||
clear_test_values("TRTEngineOp_.*:ExecuteCalibration")
|
||||
clear_test_values("TRTEngineOp_.*:ExecuteNativeSegment")
|
||||
|
||||
def _GetGPUOptions(self):
|
||||
gpu_options = config_pb2.GPUOptions()
|
||||
@ -282,7 +283,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
|
||||
def _ExpectTestValue(self, engine_name, method, expected_value):
|
||||
label = "%s:%s" % (engine_name, method)
|
||||
actual_value = trt_convert.get_test_value(label)
|
||||
actual_value = get_test_value(label)
|
||||
self.assertEqual(
|
||||
expected_value,
|
||||
actual_value,
|
||||
@ -639,5 +640,5 @@ def _AddTests(test_class):
|
||||
setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
|
||||
|
||||
|
||||
if trt_convert.is_tensorrt_enabled():
|
||||
if is_tensorrt_enabled():
|
||||
_AddTests(TfTrtIntegrationTestBase)
|
||||
|
@ -19,17 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six as _six
|
||||
# pylint: disable=unused-import,line-too-long
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||
# pylint: enable=unused-import,line-too-long
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
@ -357,6 +347,14 @@ class TrtGraphConverter(GraphConverter):
|
||||
TypeError: if any of the parameters are of unexpected type.
|
||||
ValueError: if any of the parameters are of unexpected value.
|
||||
"""
|
||||
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||
# even if it cannot find TensorRT library.
|
||||
trt_ops.load_trt_ops()
|
||||
# pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
|
||||
# Import a random symbol to trigger loading of TRT library.
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
||||
# pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable
|
||||
|
||||
if rewriter_config_template is not None and not isinstance(
|
||||
rewriter_config_template, rewriter_config_pb2.RewriterConfig):
|
||||
raise TypeError(
|
||||
@ -457,6 +455,14 @@ class TrtGraphConverter(GraphConverter):
|
||||
nodes_blacklist=nodes_blacklist,
|
||||
session_config=session_config)
|
||||
|
||||
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||
# even if it cannot find TensorRT library.
|
||||
trt_ops.load_trt_ops()
|
||||
# pylint: disable=g-import-not-at-top,line-too-long
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
|
||||
# pylint: enable=g-import-not-at-top,line-too-long
|
||||
|
||||
# Check compatibility of TensorRT version.
|
||||
compiled_version = get_linked_tensorrt_version()
|
||||
loaded_version = get_loaded_tensorrt_version()
|
||||
@ -642,6 +648,12 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
|
||||
Raises:
|
||||
RuntimeError: if the returned status message is malformed.
|
||||
"""
|
||||
# Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
|
||||
# even if it cannot find TensorRT library.
|
||||
trt_ops.load_trt_ops()
|
||||
# pylint: disable=g-import-not-at-top,line-too-long
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
|
||||
# pylint: enable=g-import-not-at-top,line-too-long
|
||||
|
||||
is_calib_graph = False
|
||||
for n in calibration_graph_def.node:
|
||||
|
@ -20,9 +20,10 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
|
||||
# pylint: enable=unused-import
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
|
||||
from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
@ -53,6 +54,8 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGetTensorrtRewriterConfig(self):
|
||||
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
|
||||
rewriter_config_template=None,
|
||||
max_batch_size=128,
|
||||
@ -172,7 +175,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testCreateInferenceGraph_BasicConversion(self):
|
||||
"""Test case for trt_convert.create_inference_graph()."""
|
||||
if not trt_convert.is_tensorrt_enabled():
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
|
||||
# Use GraphDef as input.
|
||||
@ -187,25 +190,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
output_saved_model_dir)
|
||||
|
||||
def _TestRun(self, sess, batch_size, expect_engine_is_run):
|
||||
trt_convert.clear_test_values("")
|
||||
clear_test_values("")
|
||||
result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
|
||||
self.assertAllEqual([[[4.0]]] * batch_size, result)
|
||||
execute_engine_test_value = ("done" if expect_engine_is_run else "")
|
||||
execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
|
||||
self.assertEqual(
|
||||
execute_engine_test_value,
|
||||
trt_convert.get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
|
||||
self.assertEqual(
|
||||
execute_native_segment_test_value,
|
||||
trt_convert.get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
|
||||
self.assertEqual(execute_engine_test_value,
|
||||
get_test_value("TRTEngineOp_0:ExecuteTrtEngine"))
|
||||
self.assertEqual(execute_native_segment_test_value,
|
||||
get_test_value("TRTEngineOp_0:ExecuteNativeSegment"))
|
||||
|
||||
def testCreateInferenceGraph_MinimumSegmentSize(self):
|
||||
if not trt_convert.is_tensorrt_enabled():
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
output_graph_def = trt_convert.create_inference_graph(
|
||||
self._GetGraphDef(), ["output"],
|
||||
minimum_segment_size=5,
|
||||
max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
|
||||
minimum_segment_size=5,
|
||||
is_dynamic_op=False)
|
||||
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
|
||||
self.assertEqual({
|
||||
@ -218,9 +219,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
}, node_name_to_op)
|
||||
|
||||
def testCreateInferenceGraph_DynamicOp(self):
|
||||
if not trt_convert.is_tensorrt_enabled():
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
trt_convert.enable_test_value()
|
||||
enable_test_value()
|
||||
|
||||
tmp_dir = self.get_temp_dir()
|
||||
input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
|
||||
@ -261,9 +262,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self._TestRun(sess, 3, True)
|
||||
|
||||
def testCreateInferenceGraph_StaticOp(self):
|
||||
if not trt_convert.is_tensorrt_enabled():
|
||||
if not is_tensorrt_enabled():
|
||||
return
|
||||
trt_convert.enable_test_value()
|
||||
enable_test_value()
|
||||
|
||||
tmp_dir = self.get_temp_dir()
|
||||
input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
|
||||
|
Loading…
Reference in New Issue
Block a user