Fix the unit test broken by PR #34288.

PiperOrigin-RevId: 283739627
Change-Id: If2de1fd30bd8c0e3c77d63ac2b594c2a71383ac5
This commit is contained in:
Guangda Lai 2019-12-04 05:53:57 -08:00 committed by TensorFlower Gardener
parent fd9697d81d
commit 222977dffd

View File

@ -654,9 +654,8 @@ class ConverterTest : public ::testing::Test {
ConverterTest() { Reset(); }
void Reset() {
builder_.reset(nvinfer1::createInferBuilder(logger_));
converter_ =
std::move(Converter::Create(builder_.get(), TrtPrecisionMode::FP32,
std::move(Converter::Create(TrtPrecisionMode::FP32,
/*use_calibration=*/false, &logger_)
.ValueOrDie());
weight_store_ = &converter_->weight_store_;
@ -702,9 +701,6 @@ class ConverterTest : public ::testing::Test {
private:
Logger logger_;
// These members are ordered in a way such that the destruction order is:
// converter_ -> builder_
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
protected:
std::unique_ptr<Converter> converter_;
@ -996,9 +992,7 @@ TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
FakeITensor input, infer_1, infer_2, infer_3;
FakeITensor not_infer;
Logger logger;
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(logger));
auto int8_converter = Converter::Create(builder.get(), TrtPrecisionMode::INT8,
auto int8_converter = Converter::Create(TrtPrecisionMode::INT8,
/*use_calibration=*/true, &logger)
.ValueOrDie();
int8_converter->ProvideQuantizationRange(&input, -5.0f, 5.0f);
@ -1255,12 +1249,8 @@ class OpConverterTest : public ::testing::Test {
engine_.reset(nullptr);
// Re-create them in proper order.
builder_.reset(nvinfer1::createInferBuilder(logger_));
builder_->setMaxWorkspaceSize(1 << 26);
// Reset the converter.
converter_ =
std::move(Converter::Create(builder_.get(), precision_mode_to_test_,
std::move(Converter::Create(precision_mode_to_test_,
/*use_calibration=*/false, &logger_)
.ValueOrDie());
@ -1294,18 +1284,13 @@ class OpConverterTest : public ::testing::Test {
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
// Build the TRT engine.
if (precision_mode == TrtPrecisionMode::FP16) {
builder_->setFp16Mode(true);
} else if (precision_mode == TrtPrecisionMode::INT8) {
// Setting FP16 mode as well allows TRT to also consider FP16 kernels and
// use them in situations where they are faster than INT8 or where INT8 is
// not supported for a given layer.
builder_->setFp16Mode(true);
builder_->setInt8Mode(true);
}
ASSERT_EQ(nullptr, engine_.get());
builder_->setMaxBatchSize(batch_size);
TF_ASSERT_OK(converter_->BuildCudaEngine(&engine_));
TF_ASSERT_OK(
converter_->BuildCudaEngine(&engine_,
/*max_batch_size=*/batch_size,
/*max_workspace_size_bytes=*/1 << 26,
/*allocator=*/nullptr,
/*calibrator=*/nullptr));
CHECK_NOTNULL(engine_.get());
CheckDataTypeMatches(input_data);
CheckDataTypeMatches(*output_data);
@ -1473,7 +1458,6 @@ class OpConverterTest : public ::testing::Test {
private:
Logger logger_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
// Used to create placeholders with shape and data type information. The