Fix the unit test broken by PR #34288.
PiperOrigin-RevId: 283739627 Change-Id: If2de1fd30bd8c0e3c77d63ac2b594c2a71383ac5
This commit is contained in:
parent
fd9697d81d
commit
222977dffd
@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test {
|
||||
ConverterTest() { Reset(); }
|
||||
|
||||
void Reset() {
|
||||
builder_.reset(nvinfer1::createInferBuilder(logger_));
|
||||
converter_ =
|
||||
std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32,
|
||||
std::move(Converter::Create(TrtPrecisionMode::FP32,
|
||||
/*use_calibration=*/false, &logger_)
|
||||
.ValueOrDie());
|
||||
weight_store_ = &converter_->weight_store_;
|
||||
@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test {
|
||||
|
||||
private:
|
||||
Logger logger_;
|
||||
// These members are ordered in a way such that the destruction order is:
|
||||
// converter_ -> builder_
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<Converter> converter_;
|
||||
@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
|
||||
FakeITensor input, infer_1, infer_2, infer_3;
|
||||
FakeITensor not_infer;
|
||||
Logger logger;
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder(
|
||||
nvinfer1::createInferBuilder(logger));
|
||||
auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8,
|
||||
auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
|
||||
/*use_calibration=*/true, &logger)
|
||||
.ValueOrDie();
|
||||
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
|
||||
@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test {
|
||||
engine_.reset(nullptr);
|
||||
|
||||
// Re-create them in proper order.
|
||||
builder_.reset(nvinfer1::createInferBuilder(logger_));
|
||||
builder_->setMaxWorkspaceSize(1 << 26);
|
||||
|
||||
// Reset the converter.
|
||||
converter_ =
|
||||
std::move(Converter::Create(builder_.get(), precision_mode_to_test_,
|
||||
std::move(Converter::Create(precision_mode_to_test_,
|
||||
/*use_calibration=*/false, &logger_)
|
||||
.ValueOrDie());
|
||||
|
||||
@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test {
|
||||
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
|
||||
|
||||
// Build the TRT engine.
|
||||
if (precision_mode == TrtPrecisionMode::FP16) {
|
||||
builder_->setFp16Mode(true);
|
||||
} else if (precision_mode == TrtPrecisionMode::INT8) {
|
||||
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
|
||||
// use them in situations where they are faster than INT8 or where INT8 is
|
||||
// not supported for a given layer.
|
||||
builder_->setFp16Mode(true);
|
||||
builder_->setInt8Mode(true);
|
||||
}
|
||||
ASSERT_EQ(nullptr, engine_.get());
|
||||
builder_->setMaxBatchSize(batch_size);
|
||||
TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_));
|
||||
TF_ASSERT_OK(
|
||||
converter_->BuildCudaEngine(&engine_,
|
||||
/*max_batch_size=*/batch_size,
|
||||
/*max_workspace_size_bytes=*/1 << 26,
|
||||
/*allocator=*/nullptr,
|
||||
/*calibrator=*/nullptr));
|
||||
CHECK_NOTNULL(engine_.get());
|
||||
CheckDataTypeMatches(input_data);
|
||||
CheckDataTypeMatches(*output_data);
|
||||
@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test {
|
||||
|
||||
private:
|
||||
Logger logger_;
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||
cudaStream_t stream_;
|
||||
// Used to create placeholders with shape and data type information. The
|
||||
|
Loading…
Reference in New Issue
Block a user