From d074556997f3e8aad3a1ca2bcd723dc590777aeb Mon Sep 17 00:00:00 2001
From: Sami Kama
Date: Wed, 31 Jan 2018 18:49:26 -0800
Subject: [PATCH] Fix missing defines in compilation flags that caused missing
 ops and add a simple test to validate functionality

---
 tensorflow/contrib/tensorrt/BUILD                  |  7 +--
 .../contrib/tensorrt/kernels/trt_engine_op.cc      |  2 +
 .../contrib/tensorrt/test/test_tftrt.py            | 53 +++++++++++++++++++
 3 files changed, 57 insertions(+), 5 deletions(-)
 create mode 100644 tensorflow/contrib/tensorrt/test/test_tftrt.py

diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index c0bba9f9502..6e46f95fcb0 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -52,8 +52,6 @@ tf_custom_op_library(
         ":trt_engine_op_kernel",
         ":trt_shape_function",
         "//tensorflow/core:lib_proto_parsing",
-        "//tensorflow/core/kernels:bounds_check_lib",
-        "//tensorflow/core/kernels:ops_util_hdrs",
         "@local_config_tensorrt//:nv_infer",
     ],
 )
@@ -77,8 +75,10 @@ cc_library(
     name = "trt_engine_op_kernel",
     srcs = ["kernels/trt_engine_op.cc"],
     hdrs = ["kernels/trt_engine_op.h"],
+    copts = tf_copts(),
     deps = [
         ":trt_logging",
+        "//tensorflow/core:stream_executor_headers_lib",
         "//tensorflow/core:framework_headers_lib",
         "//tensorflow/core:gpu_headers_lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -212,9 +212,6 @@ cc_library(
     linkstatic = 1,
     deps = [
         "//tensorflow/core:graph",
-        # "//tensorflow/core:core_cpu",
-        # "//tensorflow/core:lib_proto_parsing",
-        # "//third_party/eigen3",
         "@protobuf_archive//:protobuf_headers",
     ],
 )
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 983c67677fa..080e246458e 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -19,6 +19,8 @@ limitations under the License.
 #include "cuda/include/cuda_runtime_api.h"
 #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
 #include "tensorflow/core/platform/logging.h"
+//#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
 namespace tensorrt {
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
new file mode 100644
index 00000000000..06b6f64c4a5
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -0,0 +1,53 @@
+# Script to test TF-TensorRT integration
+#
+#
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+import numpy as np
+
+def getSimpleGraphDef():
+  '''
+  Create a simple graph and return its graph_def
+  '''
+  g=tf.Graph()
+  with g.as_default():
+    A=tf.placeholder(dtype=tf.float32,shape=(None,24,24,2),name="input")
+    e=tf.constant([ [[[ 1., 0.5, 4., 6., 0.5, 1. ],
+                      [ 1., 0.5, 1., 1., 0.5, 1. ]]] ],
+                   name="weights",dtype=tf.float32)
+    conv=tf.nn.conv2d(input=A,filter=e,strides=[1,2,2,1],padding="SAME",name="conv")
+    b=tf.constant([ 4., 1.5, 2., 3., 5., 7. ],
+                  name="bias",dtype=tf.float32)
+    t=tf.nn.bias_add(conv,b,name="biasAdd")
+    relu=tf.nn.relu(t,"relu")
+    idty=tf.identity(relu,"ID")
+    v=tf.nn.max_pool(idty,[1,2,2,1],[1,2,2,1],"VALID",name="max_pool")
+    out = tf.squeeze(v,name="output")
+  return g.as_graph_def()
+
+def runGraph(gdef,dumm_inp):
+  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
+  tf.reset_default_graph()
+  g=tf.Graph()
+  with g.as_default():
+    inp,out=tf.import_graph_def(graph_def=gdef,
+                                return_elements=["input","output"])
+    inp=inp.outputs[0]
+    out=out.outputs[0]
+  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options),
+                  graph=g) as sess:
+    val=sess.run(out,{inp:dumm_inp})
+  return val
+
+if "__main__" in __name__:
+  inpDims=(100,24,24,2)
+  dummy_input=np.random.random_sample(inpDims)
+  gdef=getSimpleGraphDef() #get graphdef
+  trt_graph=trt.CreateInferenceGraph(gdef,["output"],inpDims[0]) # get optimized graph
+  o1=runGraph(gdef,dummy_input)
+  o2=runGraph(trt_graph,dummy_input)
+  assert(np.array_equal(o1,o2))