Fix missing defines in compilation flags that caused missing ops and add a simple test to validate functionality
This commit is contained in:
parent
f8b1986d67
commit
d074556997
@ -52,8 +52,6 @@ tf_custom_op_library(
|
||||
":trt_engine_op_kernel",
|
||||
":trt_shape_function",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core/kernels:bounds_check_lib",
|
||||
"//tensorflow/core/kernels:ops_util_hdrs",
|
||||
"@local_config_tensorrt//:nv_infer",
|
||||
],
|
||||
)
|
||||
@ -77,8 +75,10 @@ cc_library(
|
||||
name = "trt_engine_op_kernel",
|
||||
srcs = ["kernels/trt_engine_op.cc"],
|
||||
hdrs = ["kernels/trt_engine_op.h"],
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":trt_logging",
|
||||
"//tensorflow/core:stream_executor_headers_lib",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:gpu_headers_lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
@ -212,9 +212,6 @@ cc_library(
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//tensorflow/core:graph",
|
||||
# "//tensorflow/core:core_cpu",
|
||||
# "//tensorflow/core:lib_proto_parsing",
|
||||
# "//third_party/eigen3",
|
||||
"@protobuf_archive//:protobuf_headers",
|
||||
],
|
||||
)
|
||||
|
||||
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#include "cuda/include/cuda_runtime_api.h"
|
||||
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
//#include "tensorflow/core/framework/device_base.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
53
tensorflow/contrib/tensorrt/test/test_tftrt.py
Normal file
53
tensorflow/contrib/tensorrt/test/test_tftrt.py
Normal file
@ -0,0 +1,53 @@
|
||||
# Script to test TF-TensorRT integration
|
||||
#
|
||||
#
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import tensorflow as tf
|
||||
import tensorflow.contrib.tensorrt as trt
|
||||
import numpy as np
|
||||
|
||||
def getSimpleGraphDef():
    """Build a small fixed test graph and return its GraphDef.

    The graph chains, in order: a float32 placeholder named "input" with
    shape (None, 24, 24, 2), a stride-2 SAME convolution with constant
    weights, a constant bias add, a ReLU, an identity, a 2x2 stride-2 VALID
    max-pool and a squeeze named "output".
    """
    graph = tf.Graph()
    with graph.as_default():
        features = tf.placeholder(
            dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
        # Filter literal is nested as 1 x 1 x 2 x 6.
        kernel = tf.constant(
            [[[[1., 0.5, 4., 6., 0.5, 1.],
               [1., 0.5, 1., 1., 0.5, 1.]]]],
            name="weights", dtype=tf.float32)
        conv_out = tf.nn.conv2d(
            input=features, filter=kernel, strides=[1, 2, 2, 1],
            padding="SAME", name="conv")
        offsets = tf.constant(
            [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
        biased = tf.nn.bias_add(conv_out, offsets, name="biasAdd")
        activated = tf.nn.relu(biased, "relu")
        passthrough = tf.identity(activated, "ID")
        pooled = tf.nn.max_pool(
            passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID",
            name="max_pool")
        # The op is looked up later by its name, so no local binding is
        # needed for the final node.
        tf.squeeze(pooled, name="output")
    return graph.as_graph_def()
|
||||
|
||||
def runGraph(gdef, dumm_inp):
    """Import ``gdef`` into a fresh graph, run it once and return the result.

    Args:
        gdef: graph definition containing nodes named "input" and "output".
        dumm_inp: value fed to the first output of the "input" node.

    Returns:
        The value produced by the first output of the "output" node.
    """
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
    tf.reset_default_graph()
    graph = tf.Graph()
    with graph.as_default():
        input_op, output_op = tf.import_graph_def(
            graph_def=gdef, return_elements=["input", "output"])
    # import_graph_def hands back operations; feed/fetch need their tensors.
    input_tensor = input_op.outputs[0]
    output_tensor = output_op.outputs[0]
    session_config = tf.ConfigProto(gpu_options=gpu_options)
    # The session is bound to the imported graph explicitly via graph=.
    with tf.Session(config=session_config, graph=graph) as sess:
        result = sess.run(output_tensor, {input_tensor: dumm_inp})
    return result
|
||||
if __name__ == "__main__":
    # Run one random batch through the original graph and through its
    # TensorRT-converted counterpart, then require identical outputs.
    inpDims = (100, 24, 24, 2)
    dummy_input = np.random.random_sample(inpDims)
    gdef = getSimpleGraphDef()  # reference graph definition
    # inpDims[0] (the leading input dimension) is passed as the third
    # argument -- presumably the maximum batch size; confirm against the
    # converter's signature.
    trt_graph = trt.CreateInferenceGraph(gdef, ["output"], inpDims[0])
    o1 = runGraph(gdef, dummy_input)
    o2 = runGraph(trt_graph, dummy_input)
    # Raise explicitly rather than using `assert`, which is stripped when
    # Python runs with -O and would let this script pass without checking.
    if not np.array_equal(o1, o2):
        raise AssertionError(
            "TensorRT-converted graph output differs from the original "
            "graph output")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user