Move methods that will be used in python to separate py_utils.h/cc files.
This helps to remove the conversin library dependency from the swig libray, so we can make a separate shared library for TF-TRT later. PiperOrigin-RevId: 236765704
This commit is contained in:
parent
82ff4f8367
commit
ec2dc801e8
@ -431,3 +431,13 @@ cc_library(
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_utils",
|
||||
srcs = ["utils/py_utils.cc"],
|
||||
hdrs = ["utils/py_utils.h"],
|
||||
copts = tf_copts(),
|
||||
deps = if_tensorrt([
|
||||
"@local_config_tensorrt//:tensorrt",
|
||||
]),
|
||||
)
|
||||
|
@ -65,21 +65,6 @@ namespace convert {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
// Returns compiled TRT version information {Maj, Min, Patch}
|
||||
std::vector<int> GetLinkedTensorRTVersion() {
|
||||
return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH};
|
||||
}
|
||||
|
||||
// Returns loaded TRT library version {Maj, Min, Patch}
|
||||
std::vector<int> GetLoadedTensorRTVersion() {
|
||||
int ver = getInferLibVersion();
|
||||
int ver_major = ver / 1000;
|
||||
ver = ver - ver_major * 1000;
|
||||
int ver_minor = ver / 100;
|
||||
int ver_patch = ver - ver_minor * 100;
|
||||
return {ver_major, ver_minor, ver_patch};
|
||||
}
|
||||
|
||||
TrtCandidateSelector::TrtCandidateSelector(
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode)
|
||||
|
@ -92,12 +92,6 @@ Status ConvertGraphDefToTensorRT(
|
||||
// Method to call from optimization pass
|
||||
Status ConvertAfterShapes(const ConversionParams& params);
|
||||
|
||||
// Return compile time TensorRT library version information.
|
||||
std::vector<int> GetLinkedTensorRTVersion();
|
||||
|
||||
// Return runtime time TensorRT library version information.
|
||||
std::vector<int> GetLoadedTensorRTVersion();
|
||||
|
||||
// Helper method for the conversion, expose for testing.
|
||||
std::pair<int, Allocator*> GetDeviceAndAllocator(const ConversionParams& params,
|
||||
const EngineInfo& engine);
|
||||
|
@ -21,19 +21,6 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
bool IsGoogleTensorRTEnabled() {
|
||||
// TODO(laigd): consider also checking if tensorrt shared libraries are
|
||||
// accessible. We can then direct users to this function to make sure they can
|
||||
// safely write code that uses tensorrt conditionally. E.g. if it does not
|
||||
// check for for tensorrt, and user mistakenly uses tensorrt, they will just
|
||||
// crash and burn.
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) {
|
||||
switch (mode) {
|
||||
case TrtPrecisionMode::FP32:
|
||||
|
@ -33,8 +33,6 @@ struct TrtDestroyer {
|
||||
template <typename T>
|
||||
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
|
||||
|
||||
bool IsGoogleTensorRTEnabled();
|
||||
|
||||
enum class TrtPrecisionMode { FP32, FP16, INT8 };
|
||||
|
||||
Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);
|
||||
|
65
tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
Normal file
65
tensorflow/compiler/tf2tensorrt/utils/py_utils.cc
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h"
|
||||
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
#include "tensorrt/include/NvInfer.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
bool IsGoogleTensorRTEnabled() {
|
||||
// TODO(laigd): consider also checking if tensorrt shared libraries are
|
||||
// accessible. We can then direct users to this function to make sure they can
|
||||
// safely write code that uses tensorrt conditionally. E.g. if it does not
|
||||
// check for for tensorrt, and user mistakenly uses tensorrt, they will just
|
||||
// crash and burn.
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) {
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
*major = NV_TENSORRT_MAJOR;
|
||||
*minor = NV_TENSORRT_MINOR;
|
||||
*patch = NV_TENSORRT_PATCH;
|
||||
#else
|
||||
*major = 0;
|
||||
*minor = 0;
|
||||
*patch = 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) {
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
int ver = getInferLibVersion();
|
||||
*major = ver / 1000;
|
||||
ver = ver - *major * 1000;
|
||||
*minor = ver / 100;
|
||||
*patch = ver - *minor * 100;
|
||||
#else
|
||||
*major = 0;
|
||||
*minor = 0;
|
||||
*patch = 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
33
tensorflow/compiler/tf2tensorrt/utils/py_utils.h
Normal file
33
tensorflow/compiler/tf2tensorrt/utils/py_utils.h
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
bool IsGoogleTensorRTEnabled();
|
||||
|
||||
// Return compile time TensorRT library version information {Maj, Min, Patch}.
|
||||
void GetLinkedTensorRTVersion(int* major, int* minor, int* patch);
|
||||
|
||||
// Return runtime time TensorRT library version information {Maj, Min, Patch}.
|
||||
void GetLoadedTensorRTVersion(int* major, int* minor, int* patch);
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_
|
@ -77,6 +77,7 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/python:platform/base.i",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2tensorrt:py_utils",
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_conversion",
|
||||
"//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
|
||||
"//third_party/python_runtime:headers",
|
||||
|
@ -40,18 +40,6 @@ PyObject* version_helper(version_struct* in) {
|
||||
return tuple;
|
||||
}
|
||||
|
||||
/* Define converters for vector<int> */
|
||||
template<>
|
||||
bool _PyObjAs(PyObject *pyobj, int* dest) {
|
||||
*dest = PyLong_AsLong(pyobj);
|
||||
return true;
|
||||
}
|
||||
|
||||
template<>
|
||||
PyObject *_PyObjFrom(const int& src) {
|
||||
return PyLong_FromLong(src);
|
||||
}
|
||||
|
||||
%}
|
||||
|
||||
_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
|
||||
@ -63,12 +51,10 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
%unignore tensorflow;
|
||||
%unignore get_linked_tensorrt_version;
|
||||
%unignore get_loaded_tensorrt_version;
|
||||
%unignore is_tensorrt_enabled;
|
||||
@ -78,24 +64,16 @@ _LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
|
||||
version_struct get_linked_tensorrt_version() {
|
||||
// Return the version at the link time.
|
||||
version_struct s;
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion();
|
||||
s.vmajor = lv[0];
|
||||
s.vminor = lv[1];
|
||||
s.vpatch = lv[2];
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
tensorflow::tensorrt::GetLinkedTensorRTVersion(
|
||||
&s.vmajor, &s.vminor, &s.vpatch);
|
||||
return s;
|
||||
}
|
||||
|
||||
version_struct get_loaded_tensorrt_version() {
|
||||
// Return the version from the loaded library.
|
||||
version_struct s;
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion();
|
||||
s.vmajor = lv[0];
|
||||
s.vminor = lv[1];
|
||||
s.vpatch = lv[2];
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
tensorflow::tensorrt::GetLoadedTensorRTVersion(
|
||||
&s.vmajor, &s.vminor, &s.vpatch);
|
||||
return s;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user