Use IScaleLayer for BiasAdd when doing conversion in int mode without
calibration, so that TRT can fuse the Conv2D and BiasAdd layers and won't require range information for the output of the Conv2D. PiperOrigin-RevId: 255183036
This commit is contained in:
parent
f7415d1efb
commit
f6209aba67
@ -718,7 +718,8 @@ Status ConvertAfterShapes(const ConversionParams& params) {
|
||||
}
|
||||
segment_options.minimum_segment_size = params.minimum_segment_size;
|
||||
segment::SegmentNodesVector initial_segments;
|
||||
TrtNodeValidator validator(*params.graph_properties, params.precision_mode);
|
||||
TrtNodeValidator validator(*params.graph_properties, params.precision_mode,
|
||||
params.use_calibration);
|
||||
TF_RETURN_IF_ERROR(segment::SegmentGraph(
|
||||
&graph,
|
||||
std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator,
|
||||
|
@ -955,6 +955,31 @@ TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype,
|
||||
return weights;
|
||||
}
|
||||
|
||||
OpConverterParams::OpConverterParams(
|
||||
const NodeDef& node_def, const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration)
|
||||
: node_def(node_def),
|
||||
inputs(inputs),
|
||||
outputs(outputs),
|
||||
validation_only(true),
|
||||
weight_store(weight_store),
|
||||
precision_mode(precision_mode),
|
||||
use_calibration(use_calibration) {}
|
||||
|
||||
OpConverterParams::OpConverterParams(
|
||||
Converter* converter, const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs, TrtWeightStore* weight_store)
|
||||
: converter(converter),
|
||||
node_def(node_def),
|
||||
inputs(inputs),
|
||||
outputs(outputs),
|
||||
validation_only(false),
|
||||
weight_store(weight_store),
|
||||
precision_mode(converter->precision_mode()),
|
||||
use_calibration(converter->use_calibration()) {}
|
||||
|
||||
const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
|
||||
"QuantizeAndDequantizeV2",
|
||||
"QuantizeAndDequantizeV3",
|
||||
@ -964,8 +989,10 @@ const std::set<string>* TrtNodeValidator::quantize_ops = new std::set<string>{
|
||||
|
||||
TrtNodeValidator::TrtNodeValidator(
|
||||
const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode)
|
||||
: graph_properties_(graph_properties), precision_mode_(precision_mode) {
|
||||
TrtPrecisionMode precision_mode, bool use_calibration)
|
||||
: graph_properties_(graph_properties),
|
||||
precision_mode_(precision_mode),
|
||||
use_calibration_(use_calibration) {
|
||||
RegisterOpValidators();
|
||||
}
|
||||
|
||||
@ -1044,9 +1071,8 @@ Status TrtNodeValidator::IsTensorRTCandidate(const Node* node) {
|
||||
}
|
||||
|
||||
OpConverter validator = op_validators_[op];
|
||||
OpConverterParams params(
|
||||
/*arg_converter=*/nullptr, node->def(), inputs, /*arg_outputs=*/nullptr,
|
||||
/*arg_validation_only=*/true, &weight_store_);
|
||||
OpConverterParams params(node->def(), inputs, /*arg_outputs=*/nullptr,
|
||||
&weight_store_, precision_mode_, use_calibration_);
|
||||
return validator(¶ms);
|
||||
}
|
||||
|
||||
@ -1055,9 +1081,8 @@ Status TrtNodeValidator::ConvertConstToWeights(
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
TRT_TensorOrWeights* output) {
|
||||
std::vector<TRT_TensorOrWeights> outputs;
|
||||
OpConverterParams params(
|
||||
/*arg_converter=*/nullptr, const_node_def, inputs, &outputs,
|
||||
/*arg_validation_only=*/true, &weight_store_);
|
||||
OpConverterParams params(const_node_def, inputs, &outputs, &weight_store_,
|
||||
precision_mode_, use_calibration_);
|
||||
Status status = op_validators_["Const"](¶ms);
|
||||
if (status.ok() && output) *output = outputs[0];
|
||||
return status;
|
||||
@ -1109,8 +1134,7 @@ Status Converter::ConvertNode(const NodeDef& node_def) {
|
||||
std::vector<TRT_TensorOrWeights> inputs, outputs;
|
||||
TF_RETURN_IF_ERROR(this->GetInputs(node_def, &inputs));
|
||||
|
||||
OpConverterParams params(this, node_def, inputs, &outputs,
|
||||
/*arg_validation_only=*/false, &weight_store_);
|
||||
OpConverterParams params(this, node_def, inputs, &outputs, &weight_store_);
|
||||
const string& op = node_def.op();
|
||||
auto itr = op_registry_.find(op);
|
||||
if (itr == op_registry_.end()) {
|
||||
@ -2796,7 +2820,7 @@ Status ConvertRelu6(OpConverterParams* params) {
|
||||
#endif
|
||||
}
|
||||
|
||||
Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
Status ConvertBiasAddInt8WithoutCalibration(OpConverterParams* params) {
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -2894,6 +2918,71 @@ Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConvertBiasAdd(OpConverterParams* params) {
|
||||
if (params->precision_mode == TrtPrecisionMode::INT8 &&
|
||||
!params->use_calibration) {
|
||||
// NOTE(laigd): based on some observation, it seems TensorRT cannot fuse
|
||||
// IConvolutionLayer and IElementwiseLayer and will require range
|
||||
// information for the output of Conv2D. Using IScaleLayer will fix the
|
||||
// problem.
|
||||
return ConvertBiasAddInt8WithoutCalibration(params);
|
||||
}
|
||||
const auto& inputs = params->inputs;
|
||||
const auto& node_def = params->node_def;
|
||||
|
||||
if (inputs.size() != 2) {
|
||||
return errors::InvalidArgument(
|
||||
"BiasAdd expects exactly 2 inputs, but received ", inputs.size());
|
||||
}
|
||||
|
||||
if (inputs[0].is_weights() && inputs[1].is_weights()) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs are weights, but Grappler is expected to fold them.");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
|
||||
|
||||
TFAttrs attrs(node_def);
|
||||
const string& data_format = attrs.get<string>("data_format");
|
||||
|
||||
nvinfer1::Dims input_shape = inputs.at(0).GetTrtDims();
|
||||
nvinfer1::Dims bias_shape = inputs.at(1).GetTrtDims();
|
||||
// If the input is NCHW, then we need to unsqueeze the bias such that its last
|
||||
// dimensions are 1s (and the first dimension is C).
|
||||
if (data_format == "NCHW") {
|
||||
bias_shape.nbDims = input_shape.nbDims;
|
||||
std::fill(bias_shape.d + 1, bias_shape.d + bias_shape.nbDims, 1);
|
||||
} else {
|
||||
// Next, broadcast the bias across the input.
|
||||
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(inputs.at(0), inputs.at(1),
|
||||
&input_shape, &bias_shape));
|
||||
}
|
||||
|
||||
// Convert input to a TRT tensor
|
||||
nvinfer1::ITensor* input_tensor{nullptr};
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(0), input_shape, params->validation_only, &input_tensor));
|
||||
|
||||
// Finally, reshape bias. Since the bias is usually a constant, this will
|
||||
// normally happen at conversion-time.
|
||||
nvinfer1::ITensor* bias_tensor{nullptr};
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
inputs.at(1), bias_shape, params->validation_only, &bias_tensor));
|
||||
VLOG(2) << "Bias shape adjusted to " << DebugString(bias_shape);
|
||||
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
nvinfer1::IElementWiseLayer* layer =
|
||||
params->converter->network()->addElementWise(
|
||||
*input_tensor, *bias_tensor, nvinfer1::ElementWiseOperation::kSUM);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
|
||||
|
||||
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GetTensorDimsWithProtoShape(const Tensor& tensor, nvinfer1::Dims* dims) {
|
||||
if (tensor.dims() > 0) {
|
||||
*dims = GetTrtDimsForTensor(tensor);
|
||||
|
@ -353,23 +353,27 @@ class Converter;
|
||||
|
||||
// Parameters for each op converter.
|
||||
struct OpConverterParams {
|
||||
OpConverterParams(Converter* arg_converter, const NodeDef& arg_node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& arg_inputs,
|
||||
std::vector<TRT_TensorOrWeights>* arg_outputs,
|
||||
bool arg_validation_only, TrtWeightStore* arg_weight_store)
|
||||
: converter(arg_converter),
|
||||
node_def(arg_node_def),
|
||||
inputs(arg_inputs),
|
||||
outputs(arg_outputs),
|
||||
validation_only(arg_validation_only),
|
||||
weight_store(arg_weight_store) {}
|
||||
// Constructor used for validation only.
|
||||
OpConverterParams(const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs,
|
||||
TrtWeightStore* weight_store,
|
||||
TrtPrecisionMode precision_mode, bool use_calibration);
|
||||
|
||||
Converter* converter;
|
||||
// Constructor used for conversion.
|
||||
OpConverterParams(Converter* converter, const NodeDef& node_def,
|
||||
const std::vector<TRT_TensorOrWeights>& inputs,
|
||||
std::vector<TRT_TensorOrWeights>* outputs,
|
||||
TrtWeightStore* weight_store);
|
||||
|
||||
Converter* converter = nullptr;
|
||||
const NodeDef& node_def;
|
||||
const std::vector<TRT_TensorOrWeights>& inputs;
|
||||
std::vector<TRT_TensorOrWeights>* outputs;
|
||||
const bool validation_only;
|
||||
TrtWeightStore* weight_store;
|
||||
const TrtPrecisionMode precision_mode;
|
||||
const bool use_calibration;
|
||||
};
|
||||
|
||||
using OpConverter = std::function<Status(OpConverterParams*)>;
|
||||
@ -381,7 +385,7 @@ class TrtNodeValidator {
|
||||
// checked by IsTensorRTCandidate() later. It is used to get the shape and
|
||||
// data type information of a tensor for validation purpose.
|
||||
TrtNodeValidator(const grappler::GraphProperties& graph_properties,
|
||||
TrtPrecisionMode precision_mode);
|
||||
TrtPrecisionMode precision_mode, bool use_calibration);
|
||||
|
||||
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
|
||||
// to TRT subgraph and later converted into TRT engine.
|
||||
@ -419,6 +423,8 @@ class TrtNodeValidator {
|
||||
// Quantization ops are only converted when using quantized precisions.
|
||||
const TrtPrecisionMode precision_mode_;
|
||||
|
||||
const bool use_calibration_;
|
||||
|
||||
friend class ValidatorTest;
|
||||
friend class OpConverterTest;
|
||||
};
|
||||
|
@ -460,7 +460,8 @@ class ValidatorTest : public ::testing::Test {
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32);
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
|
||||
/*use_calibration=*/false);
|
||||
return validator.ConvertToTensorOrWeights(node->def(), output_port,
|
||||
tensor_or_weights);
|
||||
}
|
||||
@ -473,7 +474,8 @@ class ValidatorTest : public ::testing::Test {
|
||||
TEST_F(ValidatorTest, QuantizeOpsAreRegistered) {
|
||||
grappler::GrapplerItem item;
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32);
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
|
||||
/*use_calibration=*/false);
|
||||
for (const string& quantize_op : *GetQuantizeOps(&validator)) {
|
||||
QCHECK(op_validators(&validator).count(quantize_op));
|
||||
}
|
||||
@ -542,7 +544,8 @@ TEST_F(ValidatorTest, IsTensorRTCandidate_Basics) {
|
||||
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32);
|
||||
TrtNodeValidator validator(graph_properties, TrtPrecisionMode::FP32,
|
||||
/*use_calibration=*/false);
|
||||
|
||||
bool start_conversion = false;
|
||||
bool should_fail = false;
|
||||
@ -620,7 +623,8 @@ TEST(TrtNodeValidator, IsTensorRTCandidate) {
|
||||
|
||||
for (const TrtPrecisionMode precision_mode :
|
||||
{TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) {
|
||||
TrtNodeValidator validator(graph_properties, precision_mode);
|
||||
TrtNodeValidator validator(graph_properties, precision_mode,
|
||||
/*use_calibration=*/false);
|
||||
TF_EXPECT_OK(validator.IsTensorRTCandidate(matmul.operation.node()));
|
||||
ExpectStatus(
|
||||
validator.IsTensorRTCandidate(incompatible_matmul.operation.node()),
|
||||
@ -1402,7 +1406,8 @@ class OpConverterTest : public ::testing::Test {
|
||||
grappler::GraphProperties graph_properties(item);
|
||||
TF_EXPECT_OK(graph_properties.InferStatically(true));
|
||||
|
||||
TrtNodeValidator validator(graph_properties, precision_mode_to_test_);
|
||||
TrtNodeValidator validator(graph_properties, precision_mode_to_test_,
|
||||
/*use_calibration=*/false);
|
||||
ExpectStatus(validator.IsTensorRTCandidate(node), expected_code,
|
||||
expected_msg_substr);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user