Fix missing defines in compilation flags that caused missing ops and add a simple test to validate functionality
This commit is contained in:
parent
f8b1986d67
commit
d074556997
@ -52,8 +52,6 @@ tf_custom_op_library(
|
|||||||
":trt_engine_op_kernel",
|
":trt_engine_op_kernel",
|
||||||
":trt_shape_function",
|
":trt_shape_function",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
"//tensorflow/core/kernels:bounds_check_lib",
|
|
||||||
"//tensorflow/core/kernels:ops_util_hdrs",
|
|
||||||
"@local_config_tensorrt//:nv_infer",
|
"@local_config_tensorrt//:nv_infer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -77,8 +75,10 @@ cc_library(
|
|||||||
name = "trt_engine_op_kernel",
|
name = "trt_engine_op_kernel",
|
||||||
srcs = ["kernels/trt_engine_op.cc"],
|
srcs = ["kernels/trt_engine_op.cc"],
|
||||||
hdrs = ["kernels/trt_engine_op.h"],
|
hdrs = ["kernels/trt_engine_op.h"],
|
||||||
|
copts = tf_copts(),
|
||||||
deps = [
|
deps = [
|
||||||
":trt_logging",
|
":trt_logging",
|
||||||
|
"//tensorflow/core:stream_executor_headers_lib",
|
||||||
"//tensorflow/core:framework_headers_lib",
|
"//tensorflow/core:framework_headers_lib",
|
||||||
"//tensorflow/core:gpu_headers_lib",
|
"//tensorflow/core:gpu_headers_lib",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
@ -212,9 +212,6 @@ cc_library(
|
|||||||
linkstatic = 1,
|
linkstatic = 1,
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:graph",
|
"//tensorflow/core:graph",
|
||||||
# "//tensorflow/core:core_cpu",
|
|
||||||
# "//tensorflow/core:lib_proto_parsing",
|
|
||||||
# "//third_party/eigen3",
|
|
||||||
"@protobuf_archive//:protobuf_headers",
|
"@protobuf_archive//:protobuf_headers",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,6 +19,8 @@ limitations under the License.
|
|||||||
#include "cuda/include/cuda_runtime_api.h"
|
#include "cuda/include/cuda_runtime_api.h"
|
||||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
//#include "tensorflow/core/framework/device_base.h"
|
||||||
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
|
|||||||
53
tensorflow/contrib/tensorrt/test/test_tftrt.py
Normal file
53
tensorflow/contrib/tensorrt/test/test_tftrt.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# Script to test TF-TensorRT integration
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow as tf
|
||||||
|
import tensorflow.contrib.tensorrt as trt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def getSimpleGraphDef():
    """Create a simple graph and return its graph_def.

    The graph chains, in order: a float32 placeholder named "input" with
    shape (None, 24, 24, 2), a stride-2 "SAME" convolution ("conv") using
    the constant "weights", a bias add ("biasAdd") using the constant
    "bias", a ReLU ("relu"), an identity ("ID"), a 2x2 stride-2 "VALID"
    max pool ("max_pool") and a final squeeze named "output".

    Returns:
        The result of ``as_graph_def()`` on the newly built graph.
    """
    graph = tf.Graph()
    with graph.as_default():
        images = tf.placeholder(
            dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
        # The nested literal has shape (1, 1, 2, 6).
        kernel = tf.constant(
            [[[[1., 0.5, 4., 6., 0.5, 1.],
               [1., 0.5, 1., 1., 0.5, 1.]]]],
            name="weights", dtype=tf.float32)
        conv_out = tf.nn.conv2d(
            input=images, filter=kernel, strides=[1, 2, 2, 1],
            padding="SAME", name="conv")
        bias = tf.constant(
            [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
        biased = tf.nn.bias_add(conv_out, bias, name="biasAdd")
        activated = tf.nn.relu(biased, "relu")
        passthrough = tf.identity(activated, "ID")
        pooled = tf.nn.max_pool(
            passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID",
            name="max_pool")
        # Only the node named "output" is needed later (it is looked up by
        # name), so the returned tensor handle is not kept.
        tf.squeeze(pooled, name="output")
    return graph.as_graph_def()
|
||||||
|
|
||||||
|
def runGraph(gdef, dumm_inp):
    """Import ``gdef`` into a fresh graph and evaluate it on ``dumm_inp``.

    Args:
        gdef: Graph definition handed to ``tf.import_graph_def``; it must
            contain nodes named "input" and "output".
        dumm_inp: Value fed to the first output of the "input" node.

    Returns:
        Whatever ``Session.run`` yields for the first output of the
        "output" node.
    """
    # Cap this process at half of the GPU memory.
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.50))
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():
        input_op, output_op = tf.import_graph_def(
            graph_def=gdef, return_elements=["input", "output"])
        # return_elements yields the named elements; take the first output
        # tensor of each one for feeding / fetching.
        feed_tensor = input_op.outputs[0]
        fetch_tensor = output_op.outputs[0]
    with tf.Session(config=session_config, graph=graph) as sess:
        return sess.run(fetch_tensor, {feed_tensor: dumm_inp})
|
||||||
|
if __name__ == "__main__":
    # Script entry point: build the reference graph, convert it with
    # TensorRT, run both on the same random input and require identical
    # outputs.
    inpDims = (100, 24, 24, 2)
    dummy_input = np.random.random_sample(inpDims)
    gdef = getSimpleGraphDef()  # reference graph definition
    # The third argument is the batch dimension of the input
    # (presumably the maximum batch size -- TODO confirm against the
    # CreateInferenceGraph signature).
    trt_graph = trt.CreateInferenceGraph(gdef, ["output"], inpDims[0])
    o1 = runGraph(gdef, dummy_input)
    o2 = runGraph(trt_graph, dummy_input)
    # Raise explicitly instead of using a bare ``assert`` so the check is
    # not stripped when the script runs under ``python -O``.
    if not np.array_equal(o1, o2):
        raise AssertionError(
            "TensorRT-converted graph output differs from the original "
            "graph output")
|
||||||
|
|
||||||
Loading…
x
Reference in New Issue
Block a user