Fix missing defines in the compilation flags, which caused missing ops, and add a simple test to validate functionality

Sami Kama 2018-01-31 18:49:26 -08:00
parent f8b1986d67
commit d074556997
3 changed files with 57 additions and 5 deletions


@@ -52,8 +52,6 @@ tf_custom_op_library(
         ":trt_engine_op_kernel",
         ":trt_shape_function",
         "//tensorflow/core:lib_proto_parsing",
-        "//tensorflow/core/kernels:bounds_check_lib",
-        "//tensorflow/core/kernels:ops_util_hdrs",
         "@local_config_tensorrt//:nv_infer",
     ],
 )
@@ -77,8 +75,10 @@ cc_library(
     name = "trt_engine_op_kernel",
     srcs = ["kernels/trt_engine_op.cc"],
     hdrs = ["kernels/trt_engine_op.h"],
+    copts = tf_copts(),
     deps = [
         ":trt_logging",
+        "//tensorflow/core:stream_executor_headers_lib",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:gpu_headers_lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -212,9 +212,6 @@ cc_library(
     linkstatic = 1,
     deps = [
         "//tensorflow/core:graph",
-        # "//tensorflow/core:core_cpu",
-        # "//tensorflow/core:lib_proto_parsing",
-        # "//third_party/eigen3",
         "@protobuf_archive//:protobuf_headers",
     ],
 )
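
For context, tf_copts() pulls in TensorFlow's standard compile options, including the preprocessor defines (such as GOOGLE_CUDA) that gate GPU-only code paths; building the kernel without them can leave the op unregistered at runtime, which is the "missing ops" symptom this commit fixes. Below is a minimal sketch, not part of this commit, of how one could check that conversion actually produced a TensorRT node; the registered op type name "TRTEngineOp" is an assumption here, as is the helper name.

import tensorflow.contrib.tensorrt as trt


def has_trt_engine_op(graph_def):
  # GraphDef.node / NodeDef.op are standard proto fields; "TRTEngineOp"
  # as the registered op type is an assumption for illustration.
  return any(node.op == "TRTEngineOp" for node in graph_def.node)


# Example use with the conversion call from the test below:
#   has_trt_engine_op(trt.CreateInferenceGraph(gdef, ["output"], 100))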


@@ -19,6 +19,8 @@ limitations under the License.
 #include "cuda/include/cuda_runtime_api.h"
 #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
 #include "tensorflow/core/platform/logging.h"
+//#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/platform/stream_executor.h"
 namespace tensorflow {
 namespace tensorrt {
 

@@ -0,0 +1,53 @@
+# Script to test TF-TensorRT integration
+#
+#
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+import numpy as np
+def getSimpleGraphDef():
+  '''
+  Create a simple graph and return its graph_def.
+  '''
+  g = tf.Graph()
+  with g.as_default():
+    A = tf.placeholder(dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
+    e = tf.constant([[[[1., 0.5, 4., 6., 0.5, 1.],
+                       [1., 0.5, 1., 1., 0.5, 1.]]]],
+                    name="weights", dtype=tf.float32)
+    conv = tf.nn.conv2d(input=A, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
+    b = tf.constant([4., 1.5, 2., 3., 5., 7.],
+                    name="bias", dtype=tf.float32)
+    t = tf.nn.bias_add(conv, b, name="biasAdd")
+    relu = tf.nn.relu(t, "relu")
+    idty = tf.identity(relu, "ID")
+    v = tf.nn.max_pool(idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+    out = tf.squeeze(v, name="output")
+  return g.as_graph_def()
+
+
+def runGraph(gdef, dumm_inp):
+  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
+  tf.reset_default_graph()
+  g = tf.Graph()
+  with g.as_default():
+    inp, out = tf.import_graph_def(graph_def=gdef,
+                                   return_elements=["input", "output"])
+    inp = inp.outputs[0]
+    out = out.outputs[0]
+  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options),
+                  graph=g) as sess:
+    val = sess.run(out, {inp: dumm_inp})
+  return val
+
+
+if __name__ == "__main__":
+  inpDims = (100, 24, 24, 2)
+  dummy_input = np.random.random_sample(inpDims)
+  gdef = getSimpleGraphDef()  # Get the plain TensorFlow graph_def.
+  trt_graph = trt.CreateInferenceGraph(gdef, ["output"], inpDims[0])  # Get the optimized graph.
+  o1 = runGraph(gdef, dummy_input)
+  o2 = runGraph(trt_graph, dummy_input)
+  assert np.array_equal(o1, o2)
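
A usage note on the final assertion: it demands bitwise-identical outputs from the native TensorFlow run and the TensorRT run, which works for this small graph but is fragile in general, since the two engines may order floating-point operations differently. A looser check, with illustrative tolerances (an assumption, not part of the commit), would be:

import numpy as np

# rtol/atol are illustrative; pick tolerances appropriate to the model.
assert np.allclose(o1, o2, rtol=1e-3, atol=1e-6)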