Fix the unit test broken by PR #34288.
PiperOrigin-RevId: 283739627 Change-Id: If2de1fd30bd8c0e3c77d63ac2b594c2a71383ac5
This commit is contained in:
parent
fd9697d81d
commit
222977dffd
@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test {
|
|||||||
ConverterTest() { Reset(); }
|
ConverterTest() { Reset(); }
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
builder_.reset(nvinfer1::createInferBuilder(logger_));
|
|
||||||
converter_ =
|
converter_ =
|
||||||
std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32,
|
std::move(Converter::Create(TrtPrecisionMode::FP32,
|
||||||
/*use_calibration=*/false, &logger_)
|
/*use_calibration=*/false, &logger_)
|
||||||
.ValueOrDie());
|
.ValueOrDie());
|
||||||
weight_store_ = &converter_->weight_store_;
|
weight_store_ = &converter_->weight_store_;
|
||||||
@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Logger logger_;
|
Logger logger_;
|
||||||
// These members are ordered in a way such that the destruction order is:
|
|
||||||
// converter_ -> builder_
|
|
||||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::unique_ptr<Converter> converter_;
|
std::unique_ptr<Converter> converter_;
|
||||||
@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
|
|||||||
FakeITensor input, infer_1, infer_2, infer_3;
|
FakeITensor input, infer_1, infer_2, infer_3;
|
||||||
FakeITensor not_infer;
|
FakeITensor not_infer;
|
||||||
Logger logger;
|
Logger logger;
|
||||||
TrtUniquePtrType<nvinfer1::IBuilder> builder(
|
auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
|
||||||
nvinfer1::createInferBuilder(logger));
|
|
||||||
auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8,
|
|
||||||
/*use_calibration=*/true, &logger)
|
/*use_calibration=*/true, &logger)
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
|
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
|
||||||
@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test {
|
|||||||
engine_.reset(nullptr);
|
engine_.reset(nullptr);
|
||||||
|
|
||||||
// Re-create them in proper order.
|
// Re-create them in proper order.
|
||||||
builder_.reset(nvinfer1::createInferBuilder(logger_));
|
|
||||||
builder_->setMaxWorkspaceSize(1 << 26);
|
|
||||||
|
|
||||||
// Reset the converter.
|
|
||||||
converter_ =
|
converter_ =
|
||||||
std::move(Converter::Create(builder_.get(), precision_mode_to_test_,
|
std::move(Converter::Create(precision_mode_to_test_,
|
||||||
/*use_calibration=*/false, &logger_)
|
/*use_calibration=*/false, &logger_)
|
||||||
.ValueOrDie());
|
.ValueOrDie());
|
||||||
|
|
||||||
@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test {
|
|||||||
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
|
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
|
||||||
|
|
||||||
// Build the TRT engine.
|
// Build the TRT engine.
|
||||||
if (precision_mode == TrtPrecisionMode::FP16) {
|
|
||||||
builder_->setFp16Mode(true);
|
|
||||||
} else if (precision_mode == TrtPrecisionMode::INT8) {
|
|
||||||
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
|
|
||||||
// use them in situations where they are faster than INT8 or where INT8 is
|
|
||||||
// not supported for a given layer.
|
|
||||||
builder_->setFp16Mode(true);
|
|
||||||
builder_->setInt8Mode(true);
|
|
||||||
}
|
|
||||||
ASSERT_EQ(nullptr, engine_.get());
|
ASSERT_EQ(nullptr, engine_.get());
|
||||||
builder_->setMaxBatchSize(batch_size);
|
TF_ASSERT_OK(
|
||||||
TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_));
|
converter_->BuildCudaEngine(&engine_,
|
||||||
|
/*max_batch_size=*/batch_size,
|
||||||
|
/*max_workspace_size_bytes=*/1 << 26,
|
||||||
|
/*allocator=*/nullptr,
|
||||||
|
/*calibrator=*/nullptr));
|
||||||
CHECK_NOTNULL(engine_.get());
|
CHECK_NOTNULL(engine_.get());
|
||||||
CheckDataTypeMatches(input_data);
|
CheckDataTypeMatches(input_data);
|
||||||
CheckDataTypeMatches(*output_data);
|
CheckDataTypeMatches(*output_data);
|
||||||
@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Logger logger_;
|
Logger logger_;
|
||||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
|
||||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
// Used to create placeholders with shape and data type information. The
|
// Used to create placeholders with shape and data type information. The
|
||||||
|
Loading…
Reference in New Issue
Block a user