From eef787ed58abaa207745d21ae1a915e23af327f3 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Wed, 8 Aug 2018 23:39:04 -0700 Subject: [PATCH 1/6] Try to find an allocator when the engine is not assigned a device. --- tensorflow/contrib/tensorrt/BUILD | 1 + .../contrib/tensorrt/convert/convert_graph.cc | 72 ++++++++++++------- .../core/grappler/clusters/single_machine.cc | 1 + 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index fc0d22d112e..03404c1bf3a 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -280,6 +280,7 @@ tf_cuda_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core:framework_lite", + "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 21ec8b0b30c..3dae0ea4e39 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -31,6 +31,9 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/contrib/tensorrt/test/utils.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -772,33 +775,54 @@ std::pair GetDeviceAndAllocator( const ConversionParams& params, const EngineInfo& engine) { int cuda_device_id = -1; tensorflow::Allocator* dev_allocator = nullptr; - if (params.cluster) { - std::vector devices; - if (!engine.device.empty() && params.cluster->GetDeviceSet()) { - DeviceNameUtils::ParsedName parsed_name; - if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && - parsed_name.has_id) { - params.cluster->GetDeviceSet()->FindMatchingDevices(parsed_name, - &devices); + if (params.cluster == nullptr || params.cluster->GetDeviceSet() == nullptr || + engine.device.empty()) { + // If device is not set, use the first found GPU device for the conversion. + for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) { + TfGpuId tf_gpu_id(tf_gpu_id_value); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (s.ok()) { + VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " + << cuda_gpu_id.value(); + cuda_device_id = cuda_gpu_id.value(); + GPUOptions gpu_options; + // If the TF to Cuda gpu id mapping exist, the device and corresponding + // allocator must have been initialized already, so the + // GetGPUAllocator() call won't create a new allocator. + dev_allocator = GPUProcessState::singleton()->GetGPUAllocator( + gpu_options, tf_gpu_id, 1); + break; } + VLOG(2) << "TF GPU with id " << tf_gpu_id_value << " do not exist " << s; } - if (!devices.empty()) { - if (devices.size() > 1) { - string msg = "Found multiple matching devices using name '"; - StrAppend(&msg, engine.device, "': "); - for (auto d : devices) StrAppend(&msg, d->name(), ", "); - StrAppend(&msg, ". Will get the allocator from first one."); - LOG(WARNING) << msg; - } - tensorflow::AllocatorAttributes alloc_attr; - cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; - dev_allocator = devices[0]->GetAllocator(alloc_attr); - VLOG(1) << "Using allocator " << dev_allocator->Name() - << " and cuda_device_id " << cuda_device_id; - } else { - LOG(WARNING) << "Cluster is set but device '" << engine.device - << "' is not found in the cluster"; + return std::make_pair(cuda_device_id, dev_allocator); + } + + // Use the device requested by the engine. + auto device_set = params.cluster->GetDeviceSet(); + std::vector devices; + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name) && + parsed_name.has_id) { + device_set->FindMatchingDevices(parsed_name, &devices); + } + if (!devices.empty()) { + if (devices.size() > 1) { + string msg = "Found multiple matching devices using name '"; + StrAppend(&msg, engine.device, "': "); + for (auto d : devices) StrAppend(&msg, d->name(), ", "); + StrAppend(&msg, ". Will get the allocator from first one."); + LOG(WARNING) << msg; } + tensorflow::AllocatorAttributes alloc_attr; + cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; + dev_allocator = devices[0]->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name() + << " and cuda_device_id " << cuda_device_id; + } else { + LOG(WARNING) << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; } return std::make_pair(cuda_device_id, dev_allocator); } diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index b97603c890b..83fde4fe379 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -48,6 +48,7 @@ SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) (*options_.config.mutable_device_count())["CPU"] = 1; if (num_gpus > 0) { (*options_.config.mutable_device_count())["GPU"] = num_gpus; + options_.config.mutable_gpu_options()->set_allow_growth(true); } CHECK_GE(num_cpu_cores, 1); options_.config.set_intra_op_parallelism_threads(num_cpu_cores); From 7baf484688b950e74d7b75caed8f3b4cd06b4fcf Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 9 Aug 2018 10:51:37 -0700 Subject: [PATCH 2/6] Add test to reproduce the error. --- tensorflow/contrib/tensorrt/BUILD | 1 + .../test/no_device_assignment_test.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 tensorflow/contrib/tensorrt/test/no_device_assignment_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 03404c1bf3a..cb60dcbb0c9 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -399,6 +399,7 @@ cuda_py_tests( # "test/vgg_block_nchw_test.py", # "test/vgg_block_test.py", "test/memory_alignment_test.py", + "test/no_device_assignment_test.py", ], additional_deps = [ ":tf_trt_integration_test_base", diff --git a/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py b/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py new file mode 100644 index 00000000000..a06a4228604 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Basic tests for TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.python import trt_convert +# pylint: disable=unused-import +from tensorflow.contrib.tensorrt.python.ops import trt_engine_op +# pylint: enable=unused-import +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import googletest + + +class NoDeviceAssignmentTest(googletest.TestCase): + + def testNoDeviceAssignment(self): + """Test that conversion should succeed when device is not specified.""" + sess = session.Session() # By default this will consume all the gpu memory. + used_bytes = 0 + for device in sess.list_devices(): + if 'GPU:0' in device.name: + used_bytes = device.memory_limit_bytes + self.assertGreater(used_bytes, 0) + + input_dims = [100, 24, 24, 2] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtypes.float32, shape=input_dims, name='input') + for i in range(2): + mul = inp * inp + inp = mul + inp + array_ops.squeeze(inp, name='output') + + trt_gdef = trt_convert.create_inference_graph( + input_graph_def=g.as_graph_def(), + outputs=['output'], + max_batch_size=input_dims[0], + # Use half of the allocated memory. It will fail if the converter + # fallback to use native cudaMalloc(), so here it tests that converter + # doesn't fallback. + max_workspace_size_bytes=used_bytes // 4, + minimum_segment_size=2, + is_dynamic_op=False) + self.assertEqual(1, + sum([node.op == 'TRTEngineOp' for node in trt_gdef.node])) + + +if __name__ == '__main__': + test.main() From 0483e03d5e0abf053cd8440752d96d486c9cd692 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 9 Aug 2018 14:14:00 -0700 Subject: [PATCH 3/6] Fix broken test --- .../contrib/tensorrt/test/no_device_assignment_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py b/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py index a06a4228604..1d54ff3a366 100644 --- a/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py +++ b/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py @@ -18,14 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.tensorrt.python import trt_convert # pylint: disable=unused-import from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.python.client import session -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -37,6 +34,8 @@ class NoDeviceAssignmentTest(googletest.TestCase): def testNoDeviceAssignment(self): """Test that conversion should succeed when device is not specified.""" + if not trt_convert.is_tensorrt_enabled(): + return sess = session.Session() # By default this will consume all the gpu memory. used_bytes = 0 for device in sess.list_devices(): From 6b5be9a7f33462bd20bf14b0df9ca1fcb2da6bb3 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Fri, 10 Aug 2018 22:59:18 -0700 Subject: [PATCH 4/6] Revert grappler changes, the fix in convert_graph.cc is sufficient. --- tensorflow/core/grappler/clusters/single_machine.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 83fde4fe379..b97603c890b 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -48,7 +48,6 @@ SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) (*options_.config.mutable_device_count())["CPU"] = 1; if (num_gpus > 0) { (*options_.config.mutable_device_count())["GPU"] = num_gpus; - options_.config.mutable_gpu_options()->set_allow_growth(true); } CHECK_GE(num_cpu_cores, 1); options_.config.set_intra_op_parallelism_threads(num_cpu_cores); From 0b31ebb51bcdd4b0675d3f2166f316d0f22266b3 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 16 Aug 2018 11:32:10 -0700 Subject: [PATCH 5/6] Add c++ test and remove python test. --- tensorflow/contrib/tensorrt/BUILD | 24 ++- .../contrib/tensorrt/convert/convert_graph.h | 6 + .../tensorrt/convert/convert_graph_test.cc | 140 ++++++++++++++++++ .../contrib/tensorrt/convert/convert_nodes.cc | 3 + .../contrib/tensorrt/convert/convert_nodes.h | 5 +- .../test/no_device_assignment_test.py | 71 --------- 6 files changed, 175 insertions(+), 74 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/convert/convert_graph_test.cc delete mode 100644 tensorflow/contrib/tensorrt/test/no_device_assignment_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index cb60dcbb0c9..c070b71e38f 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -294,6 +294,29 @@ tf_cuda_library( ]) + tf_custom_op_library_additional_deps(), ) +tf_cuda_cc_test( + name = "convert_graph_test", + size = "medium", + srcs = ["convert/convert_graph_test.cc"], + tags = [ + "no_cuda_on_cpu_tap", + "no_windows", + "nomac", + ], + deps = [ + ":trt_conversion", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + # Library for the segmenting portion of TensorRT operation creation cc_library( name = "segment", @@ -399,7 +422,6 @@ cuda_py_tests( # "test/vgg_block_nchw_test.py", # "test/vgg_block_test.py", "test/memory_alignment_test.py", - "test/no_device_assignment_test.py", ], additional_deps = [ ":tf_trt_integration_test_base", diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 9d986e48904..35252023698 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -84,6 +85,11 @@ std::vector GetLinkedTensorRTVersion(); // Return runtime time TensorRT library version information. std::vector GetLoadedTensorRTVersion(); + +// Helper method for the conversion, expose for testing. +std::pair GetDeviceAndAllocator( + const ConversionParams& params, const EngineInfo& engine); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc new file mode 100644 index 00000000000..0d3963514a6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/convert/convert_graph.h" + +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT +#include "tensorflow/core/public/session.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +class FakeCluster : public grappler::Cluster { + public: + FakeCluster() : Cluster(0) {} + + void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; } + + const DeviceSet* GetDeviceSet() const override { return device_set_; } + + string type() const override { return ""; } + Status Provision() override { return Status::OK(); } + Status Initialize(const grappler::GrapplerItem& item) override { + return Status::OK(); + } + virtual Status Run(const GraphDef& graph_def, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) override { + return Status::OK(); + } + + private: + const DeviceSet* device_set_; +}; + +TEST(ConvertGraphTest, GetDeviceAndAllocator) { + ConversionParams params; + EngineInfo engine_info; + { + // params.cluster is not set, and no gpu device is available. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(-1, result.first); + EXPECT_EQ(nullptr, result.second); + } + + // Create a session with two (virtual) gpu device. + SessionOptions options; + ConfigProto* config = &options.config; + GPUOptions* gpu_options = config->mutable_gpu_options(); + auto virtual_devices = + gpu_options->mutable_experimental()->add_virtual_devices(); + virtual_devices->add_memory_limit_mb(200); + virtual_devices->add_memory_limit_mb(200); + std::unique_ptr session(NewSession(options)); + + { + // params.cluster is not set, should find and return first gpu id and + // corresponding allocator. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(0, result.first); + EXPECT_NE(nullptr, result.second); + EXPECT_EQ("GPU_0_bfc", result.second->Name()); + } + + FakeCluster cluster; + params.cluster = &cluster; + { + // params.cluster->GetDeviceSet() returns null, should find and return first + // gpu id and corresponding allocator. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(0, result.first); + EXPECT_NE(nullptr, result.second); + EXPECT_EQ("GPU_0_bfc", result.second->Name()); + } + + // Build the DeviceSet. + DeviceSet device_set; + const DeviceMgr* device_mgr = nullptr; + TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr)); + for (auto d : device_mgr->ListDevices()) { + device_set.AddDevice(d); + } + cluster.SetDeviceSet(&device_set); + { + // engine_info.device is not set, should find and return first gpu id and + // corresponding allocator. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(0, result.first); + EXPECT_NE(nullptr, result.second); + EXPECT_EQ("GPU_0_bfc", result.second->Name()); + } + + engine_info.device = "/GPU:1"; + { + // Set to use second device. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(0, result.first); + EXPECT_NE(nullptr, result.second); + EXPECT_EQ("GPU_1_bfc", result.second->Name()); + } + + engine_info.device = "/GPU:3"; + { + // Set to use nonexistent device. + auto result = GetDeviceAndAllocator(params, engine_info); + EXPECT_EQ(-1, result.first); + EXPECT_EQ(nullptr, result.second); + } +} + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 35fa5902541..07b4efd33f3 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -77,6 +77,9 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +const char* const kInputPHName = "TensorRTInputPH_"; +const char* const kOutputPHName = "TensorRTOutputPH_"; + namespace convert { using ::tensorflow::str_util::Split; using ::tensorflow::strings::StrAppend; diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index a60253740fe..9274027e632 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -36,8 +36,9 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -static const char* kInputPHName = "TensorRTInputPH_"; -static const char* kOutputPHName = "TensorRTOutputPH_"; +extern const char* const kInputPHName; +extern const char* const kOutputPHName; + namespace convert { struct EngineConnection { diff --git a/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py b/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py deleted file mode 100644 index 1d54ff3a366..00000000000 --- a/tensorflow/contrib/tensorrt/test/no_device_assignment_test.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Basic tests for TF-TensorRT integration.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.tensorrt.python import trt_convert -# pylint: disable=unused-import -from tensorflow.contrib.tensorrt.python.ops import trt_engine_op -# pylint: enable=unused-import -from tensorflow.python.client import session -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.platform import googletest - - -class NoDeviceAssignmentTest(googletest.TestCase): - - def testNoDeviceAssignment(self): - """Test that conversion should succeed when device is not specified.""" - if not trt_convert.is_tensorrt_enabled(): - return - sess = session.Session() # By default this will consume all the gpu memory. - used_bytes = 0 - for device in sess.list_devices(): - if 'GPU:0' in device.name: - used_bytes = device.memory_limit_bytes - self.assertGreater(used_bytes, 0) - - input_dims = [100, 24, 24, 2] - g = ops.Graph() - with g.as_default(): - inp = array_ops.placeholder( - dtype=dtypes.float32, shape=input_dims, name='input') - for i in range(2): - mul = inp * inp - inp = mul + inp - array_ops.squeeze(inp, name='output') - - trt_gdef = trt_convert.create_inference_graph( - input_graph_def=g.as_graph_def(), - outputs=['output'], - max_batch_size=input_dims[0], - # Use half of the allocated memory. It will fail if the converter - # fallback to use native cudaMalloc(), so here it tests that converter - # doesn't fallback. - max_workspace_size_bytes=used_bytes // 4, - minimum_segment_size=2, - is_dynamic_op=False) - self.assertEqual(1, - sum([node.op == 'TRTEngineOp' for node in trt_gdef.node])) - - -if __name__ == '__main__': - test.main() From 4684421d9aa3e63dc943074025ffdc89df1a1980 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Mon, 20 Aug 2018 16:04:00 -0700 Subject: [PATCH 6/6] Address review comments --- tensorflow/contrib/tensorrt/convert/convert_graph.cc | 3 ++- tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 3dae0ea4e39..b019c99882b 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -794,7 +794,8 @@ std::pair GetDeviceAndAllocator( gpu_options, tf_gpu_id, 1); break; } - VLOG(2) << "TF GPU with id " << tf_gpu_id_value << " do not exist " << s; + LOG(ERROR) << "TF GPU with id " << tf_gpu_id_value << " does not exist " + << s; } return std::make_pair(cuda_device_id, dev_allocator); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 07b4efd33f3..af52783da11 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -77,6 +77,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +// TODO(aaroey): put these constants into some class. const char* const kInputPHName = "TensorRTInputPH_"; const char* const kOutputPHName = "TensorRTOutputPH_";