[TF:TRT] Handle out-of-GPU-memory errors when creating a TensorRT execution context.
Previously, we used ICudaEngine::createExecutionContext to create a TensorRT execution context along with the GPU memory needed to execute the CUDA engine. This API doesn't handle running out of GPU memory properly; instead, it propagates an exception. This change uses ICudaEngine::createExecutionContextWithoutDeviceMemory to create a TensorRT execution context without any GPU memory, and lets TF-TRT allocate the needed GPU memory. To keep track of such GPU memory, we wrap the TensorRT execution context and the associated GPU memory in a new class called ExecutionContext.

PiperOrigin-RevId: 351895192
Change-Id: Ie01f0241578fadba8fad25bd110f937fd47082c8
parent 3c80be9f2c
commit 14d708ab72
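In essence, the change replaces direct calls to createExecutionContext with the pattern sketched below. This is a standalone illustration only, not the TF-TRT code from the diff: CreateContextWithOwnedMemory is a hypothetical helper, and cudaMalloc/cudaFree stand in for TF-TRT's TRTBaseAllocator.

// Sketch of the pattern: create the context without device memory, allocate
// the scratch space ourselves so an allocation failure becomes a recoverable
// error instead of an exception, then attach the memory to the context.
#include <cuda_runtime_api.h>
#include "NvInfer.h"

bool CreateContextWithOwnedMemory(nvinfer1::ICudaEngine* engine,
                                  nvinfer1::IExecutionContext** context_out,
                                  void** device_memory_out) {
  nvinfer1::IExecutionContext* context =
      engine->createExecutionContextWithoutDeviceMemory();
  if (context == nullptr) return false;

  void* device_memory = nullptr;
  const size_t device_memory_size = engine->getDeviceMemorySize();
  if (device_memory_size > 0) {
    if (cudaMalloc(&device_memory, device_memory_size) != cudaSuccess) {
      // Out of GPU memory: report failure to the caller instead of throwing.
      context->destroy();
      return false;
    }
    context->setDeviceMemory(device_memory);
  }
  *context_out = context;
  // The caller owns device_memory and must cudaFree it after destroying the
  // context, mirroring what ExecutionContext's destructor does in the diff.
  *device_memory_out = device_memory;
  return true;
}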
@@ -117,7 +117,6 @@ cc_library(
"//tensorflow/core:stream_executor_headers_lib",
"//tensorflow/core/common_runtime:core_cpu_lib_no_ops",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/stream_executor/lib",
] + if_tensorrt([
":tensorrt_lib",
"@local_config_cuda//cuda:cuda_headers",
@@ -240,10 +239,13 @@ tf_cuda_library(
deps = [
":trt_logging",
":utils",
":trt_allocator",
":common_utils",
"@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:status",
"//tensorflow/stream_executor/lib",
] + if_tensorrt([":tensorrt_lib"]),
)

@@ -294,6 +296,7 @@ tf_cuda_library(
],
deps = [
":trt_allocator",
":trt_engine_utils",
":trt_logging",
":utils",
"//tensorflow/core:framework_headers_lib",
@@ -47,7 +47,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/lib/statusor.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
@@ -59,7 +58,6 @@ static Logger logger;
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;
using ::stream_executor::port::StatusOr;

#define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "
@@ -132,6 +130,10 @@ class TRTEngineOp : public AsyncOpKernel {

// Returns a pair of 1) An EngineContext object that is compatible with the
// input and 2) The index of the IExecutionContext compatible with the input.
// If a cuda engine for the given input shapes can't be found, returns
// (nullptr, 0) to allow native engine execution. Returns an error code for
// any problem that would prevent both TensorRT engine execution and native
// segment execution.
StatusOr<std::pair<EngineContext*, int>> GetEngine(
const std::vector<TensorShape>& input_concrete_shapes,
OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);
@@ -914,13 +916,18 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
for (int i = 0; i < engine_input_shapes.size(); i++) {
engine_input_shapes[i].set_dim(0, max_batch_size);
}
auto exec_context_status =
ExecutionContext::Create(raw_static_engine, allocator);
if (!exec_context_status.ok()) {
return std::pair<EngineContext*, int>(&empty_context, 0);
}

// TODO(laigd): here we assume engine_input_shapes matches the actual input
// shapes of the engine, we should verify that.
cache.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(
std::move(static_engine),
TrtUniquePtrType<nvinfer1::IExecutionContext>(
raw_static_engine->createExecutionContext())));
std::move(exec_context_status.ValueOrDie())));
// Runtime is safe to delete after engine creation
VLOG(1) << "Size of serialized TRT engine: "
<< serialized_segment_.capacity();
@@ -974,12 +981,12 @@ StatusOr<std::pair<EngineContext*, int>> TRTEngineOp::GetEngine(
}
TrtUniquePtrType<nvinfer1::ICudaEngine> engine =
std::move(result.ValueOrDie());
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context;
std::vector<ExecutionContext> exec_contexts;
TF_RETURN_IF_ERROR(cache_res->profiles_.CreateExecutionContexts(
engine.get(), exec_context));
engine.get(), exec_contexts, allocator));
cache.emplace(input_concrete_shapes,
absl::make_unique<EngineContext>(std::move(engine),
std::move(exec_context)));
std::move(exec_contexts)));
VLOG(1) << "Added new engine to cache of " << name()
<< ". Cache size: " << cache.size();
engine_contexts = cache.at(input_concrete_shapes).get();
@@ -1066,11 +1073,17 @@ Status TRTEngineOp::AllocateCalibrationResources(
// dump it out during conversion for TF 2.0.
mutex_lock lock(this->engine_mutex_);
this->calibrator_ = std::move(cres->calibrator_);
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
cres->engine_->createExecutionContext());
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
std::move(exec_context)));
auto exec_context_status = ExecutionContext::Create(
cres->engine_.get(), cache_res->allocator_.get());
if (!exec_context_status.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
} else {
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(
std::move(cres->engine_),
std::move(exec_context_status.ValueOrDie())));
}
}

VLOG(1) << "Calibration loop terminated " << this->name();
@@ -139,20 +139,20 @@ class InitializeTRTResource : public OpKernel {
engine_instance.serialized_engine().c_str(),
engine_instance.serialized_engine().size(), nullptr));
auto raw_engine = engine.get();
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> ctx_vec;
std::vector<ExecutionContext> ctx_vec;
if (num_loaded_engine == 0) {
// Restore profiles if there are any. Currently only 1 engine is allowed
// in dynamic mode therefore we call this only for the 0th engine.
// it is a no-op in implicit batch mode.
OP_REQUIRES_OK(ctx, resource->profiles_.RestoreProfiles(raw_engine));
OP_REQUIRES_OK(ctx, resource->profiles_.CreateExecutionContexts(
raw_engine, ctx_vec));
raw_engine, ctx_vec, allocator));
} else {
// Multiple engines are only available in static mode. For each engine
// we have only a single execution context.
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_ctx(
raw_engine->createExecutionContext());
ctx_vec.push_back(std::move(exec_ctx));
auto exec_ctx_status = ExecutionContext::Create(raw_engine, allocator);
OP_REQUIRES_OK(ctx, exec_ctx_status.status());
ctx_vec.push_back(std::move(exec_ctx_status.ValueOrDie()));
}
resource->cache_.emplace(engine_input_shapes,
absl::make_unique<EngineContext>(
@@ -151,11 +151,14 @@ TEST_F(TRTEngineResourceOpsTest, Basic) {

// Create an engine and add it to the cache of the resource.
TrtUniquePtrType<nvinfer1::ICudaEngine> engine = CreateTRTEngine();
TrtUniquePtrType<nvinfer1::IExecutionContext> context(
engine->createExecutionContext());
auto context_status =
ExecutionContext::Create(engine.get(), resource->allocator_.get());
TF_ASSERT_OK(context_status.status());

resource->cache_.emplace(
std::vector<TensorShape>{TensorShape({1, 1})},
absl::make_unique<EngineContext>(std::move(engine), std::move(context)));
absl::make_unique<EngineContext>(std::move(engine),
std::move(context_status.ValueOrDie())));
// Check that the resource has multiple references before it is unregistered
// from the resource manager.
EXPECT_FALSE(resource->RefCountIsOne());
@@ -19,7 +19,9 @@ limitations under the License.
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
@@ -33,6 +35,42 @@ namespace tensorrt {

using absl::StrCat;

ExecutionContext::~ExecutionContext() {
if (device_memory_) {
DCHECK(memory_allocator_) << "Internal error: Device memory with address "
<< (char*)device_memory_ << " is not freed";
memory_allocator_->free(device_memory_);
}
if (execution_context_) {
execution_context_->destroy();
}
}

StatusOr<ExecutionContext> ExecutionContext::Create(
nvinfer1::ICudaEngine* cuda_engine, TRTBaseAllocator* allocator) {
void* device_memory = nullptr;
nvinfer1::IExecutionContext* execution_context;
if (allocator == nullptr) {
execution_context = cuda_engine->createExecutionContext();
} else {
execution_context =
cuda_engine->createExecutionContextWithoutDeviceMemory();
size_t device_memory_size = cuda_engine->getDeviceMemorySize();
VLOG(2) << "Device memory size for cuda engine " << device_memory_size;

if (device_memory_size > 0) {
device_memory = allocator->allocate(device_memory_size,
/*unused alignment=*/0, /*flags=*/0);
if (device_memory == nullptr) {
return errors::InvalidArgument(
"Out of GPU memory when creating execution context");
}
}
execution_context->setDeviceMemory(device_memory);
}
return ExecutionContext(allocator, device_memory, execution_context);
}

Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
const nvinfer1::IExecutionContext* execution_context,
int binding_index, bool use_implicit_batch,
@@ -23,12 +23,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
using ::stream_executor::port::StatusOr;

// Input/output data format for OpConverterTest::BuildAndRun().
struct InputOutputData {
@@ -42,6 +44,81 @@ struct InputOutputData {
Tensor tensor;
};

class TRTBaseAllocator;

// Keeps track of the TensorRT execution context and the device memory owned by
// the context, if any. An execution context owns the device memory that TF-TRT
// allocates for the context. In this case, the allocator is not null and is
// used to free the device memory. An execution context doesn't own the device
// memory (1) if the device memory is allocated through TensorRT, or (2) the
// device memory is allocated by TF-TRT for another execution context but
// shared with this context. In this case, the device memory is null.
//
// Currently, the main reason we want to allocate the device memory for an
// execution context in TF-TRT is that the TensorRT API to create an
// execution context with device memory doesn't handle out of memory properly.
//
// To support dynamic shapes, we create multiple execution contexts for an
// engine and may want to support multiple execution contexts sharing the same
// device memory.
class ExecutionContext {
private:
// Records the TensorRT execution context `context`, the device memory
// `device_memory` TF-TRT allocates for the context and the device memory
// allocator `allocator` used to allocate the memory. If TF-TRT doesn't
// allocate any device memory for the context, then `device_memory` is null;
// otherwise, `allocator` should not be null.
ExecutionContext(TRTBaseAllocator* allocator, void* device_memory,
nvinfer1::IExecutionContext* context)
: memory_allocator_(allocator),
device_memory_(device_memory),
execution_context_(context) {}

public:
// Disables copy constructors as the object owns the device memory and the
// execution context.
ExecutionContext(const ExecutionContext&) = delete;
ExecutionContext& operator=(const ExecutionContext&) = delete;

ExecutionContext(ExecutionContext&& other)
: memory_allocator_(other.memory_allocator_),
device_memory_(other.device_memory_),
execution_context_(other.execution_context_) {
other.memory_allocator_ = nullptr;
other.device_memory_ = nullptr;
other.execution_context_ = nullptr;
}

~ExecutionContext();

operator nvinfer1::IExecutionContext*() const { return execution_context_; }
nvinfer1::IExecutionContext* GetIExecutionContext() const {
return execution_context_;
}

static StatusOr<ExecutionContext> Create(nvinfer1::ICudaEngine* cuda_engine,
TRTBaseAllocator* allocator);

private:
// The allocator used to allocate and free the device memory owned by the
// execution context.
TRTBaseAllocator* memory_allocator_;
// The device memory owned by the execution context.
void* device_memory_;
// The TensorRT execution context.
nvinfer1::IExecutionContext* execution_context_;
};

// Creates a TensorRT execution context. If an allocator is not given, then the
// execution context is created with device memory allocated by TensorRT.
// Otherwise, uses the allocator to allocate the needed device memory for the
// execution context.
//
// Returns an ExecutionContext object that wraps the above results. If it runs
// out of device memory, returns an error status instead.
StatusOr<ExecutionContext> CreateExecutionContext(
nvinfer1::ICudaEngine* cuda_engine, TRTBaseAllocator* allocator);

using DataVec = std::vector<InputOutputData>;

// Gets the binding index of a tensor in an engine.
@@ -89,7 +89,7 @@ string TRTEngineCacheResource::DebugString() const {
<< "ICudaEngine: " << item.second->cuda_engine.get() << ", "
<< "IExecutionContext: ";
for (auto& ctx : item.second->execution_context) {
oss << ctx.get() << ", ";
oss << ctx.GetIExecutionContext() << ", ";
}
oss << dec << endl;
}
@@ -22,6 +22,7 @@ limitations under the License.

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
@@ -119,15 +120,13 @@ class LRUCache {

struct EngineContext {
EngineContext() {} // Creates an empty context.
EngineContext(
TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
TrtUniquePtrType<nvinfer1::IExecutionContext>&& input_execution_context)
EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
ExecutionContext&& input_execution_context)
: cuda_engine(std::move(input_cuda_engine)) {
execution_context.push_back(std::move(input_execution_context));
}
EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& input_cuda_engine,
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>&&
input_execution_context)
std::vector<ExecutionContext>&& input_execution_context)
: cuda_engine(std::move(input_cuda_engine)),
execution_context(std::move(input_execution_context)) {}

@@ -141,7 +140,7 @@ struct EngineContext {
", but only ", execution_context.size(),
"contexts are present.");
}
*exec_ctx = execution_context[idx].get();
*exec_ctx = execution_context[idx];
return Status::OK();
}

@@ -160,8 +159,7 @@ struct EngineContext {
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#thread-safety
// Additional discussion about execution context management and thread safety
// at https://github.com/tensorflow/tensorflow/issues/36959
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> execution_context
TF_GUARDED_BY(mu);
std::vector<ExecutionContext> execution_context TF_GUARDED_BY(mu);
};

// Contains the context required to build the calibration data.
@@ -109,16 +109,17 @@ int TrtShapeOptimizationProfile::GetProfileNumber(
}

Status TrtShapeOptimizationProfile::CreateExecutionContexts(
nvinfer1::ICudaEngine* engine,
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context) {
nvinfer1::ICudaEngine* engine, std::vector<ExecutionContext>& exec_context,
TRTBaseAllocator* memory_allocator) {
int i = 0;
// The following loop runs once if we have static shapes, to create a single
// execution context without profiles. In dynamic mode we create one context
// for each profile and set the corresponding optimization profile.
do {
VLOG(1) << "Creating execution context " << i;
nvinfer1::IExecutionContext* ctx = engine->createExecutionContext();
if (ctx == nullptr) {
auto exec_context_status =
ExecutionContext::Create(engine, memory_allocator);
if (!exec_context_status.ok()) {
return errors::Internal("Failed to create execution context");
}
if (i > 0) {
@@ -128,14 +129,15 @@ Status TrtShapeOptimizationProfile::CreateExecutionContexts(
// - The 0th profile is set implicitly for the first execution context
// therefore we do not need to set.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
bool stat = ctx->setOptimizationProfile(i);
bool stat = exec_context_status.ValueOrDie()
.GetIExecutionContext()
->setOptimizationProfile(i);
if (!stat) {
ctx->destroy();
return errors::Internal("Could not set TRT optimization profile.");
}
#endif
}
exec_context.push_back(TrtUniquePtrType<nvinfer1::IExecutionContext>(ctx));
exec_context.push_back(std::move(exec_context_status.ValueOrDie()));
i++;
} while (i < profiles_.size());

@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -139,9 +140,9 @@ class TrtShapeOptimizationProfile {
#endif

// Creates execution contexts for each optimization profile.
Status CreateExecutionContexts(
nvinfer1::ICudaEngine* engine,
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>>& exec_context);
Status CreateExecutionContexts(nvinfer1::ICudaEngine* engine,
std::vector<ExecutionContext>& exec_context,
TRTBaseAllocator* memory_allocator);

// Maps input vector shapes to TRT Optimization profiles (min, max, opt) i.e.
// maps input_shapes_ to profiles_
@@ -112,7 +112,7 @@ class TrtShapeOptimizationProfileTest : public ::testing::Test {
TrtUniquePtrType<nvinfer1::IBuilderConfig> builder_config_;
#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
std::vector<TrtUniquePtrType<nvinfer1::IExecutionContext>> exec_context_;
std::vector<ExecutionContext> exec_context_;
// The order is important: exec_context_ must be destroyed first, and logger
// at last.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
@@ -141,7 +141,8 @@ TEST_F(TrtShapeOptimizationProfileTest, Static) {
builder_->buildCudaEngine(*network_));
#endif
EXPECT_NE(nullptr, engine);
TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));
TF_CHECK_OK(
profile.CreateExecutionContexts(engine.get(), exec_context_, nullptr));
// A single execution context should be created for a graph with static input
ASSERT_EQ(exec_context_.size(), 1);
EXPECT_NE(nullptr, exec_context_[0]);
@@ -178,7 +179,8 @@ TEST_F(TrtShapeOptimizationProfileTest, Dynamic) {
builder_->buildEngineWithConfig(*network_.get(), *builder_config_.get()));
ASSERT_NE(nullptr, engine);

TF_CHECK_OK(profile.CreateExecutionContexts(engine.get(), exec_context_));
TF_CHECK_OK(
profile.CreateExecutionContexts(engine.get(), exec_context_, nullptr));

// Each profile has an associated execution context.
EXPECT_EQ(exec_context_.size(), input_profiles.size());
@@ -187,7 +189,8 @@ TEST_F(TrtShapeOptimizationProfileTest, Dynamic) {
for (auto dimvec : input_profiles) {
std::vector<TensorShape> shape_vec = DimVecToShapeVec(dimvec);
int idx = profile.GetProfileNumber(shape_vec);
int prof_idx = exec_context_[idx]->getOptimizationProfile();
int prof_idx =
exec_context_[idx].GetIExecutionContext()->getOptimizationProfile();
ASSERT_GE(prof_idx, 0);

for (int j = 0; j < dimvec.size(); j++) {