Save calibration table after calibration, so it can support multiple engines in
int8 mode. PiperOrigin-RevId: 263857264
This commit is contained in:
parent
ed47e70590
commit
163f9df4c7
@ -48,12 +48,10 @@ class GetCalibrationDataOp : public OpKernel {
|
||||
&resource));
|
||||
core::ScopedUnref sc(resource);
|
||||
|
||||
auto* calib_ctx = resource->calib_ctx_.get();
|
||||
|
||||
// Serialize the resource as output.
|
||||
string serialized_resource;
|
||||
OP_REQUIRES_OK(context, calib_ctx->SerializeToString(&serialized_resource));
|
||||
resource->calib_ctx_.reset();
|
||||
string serialized_resource = resource->calib_ctx_->TerminateCalibration();
|
||||
OP_REQUIRES(context, !serialized_resource.empty(),
|
||||
errors::Unknown("Calibration table is empty."));
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
|
@ -837,19 +837,18 @@ Status TRTEngineOp::AllocateCalibrationResources(
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Calibration failed: " << s;
|
||||
cres->calibrator_->setDone(); // Ignore further pushes
|
||||
} else {
|
||||
// Transfer the ownership of the engine to the engine cache, so we can
|
||||
// dump it out during conversion for TF 2.0.
|
||||
mutex_lock lock(this->engine_mutex_);
|
||||
this->calibrator_ = std::move(cres->calibrator_);
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
cres->engine_->createExecutionContext());
|
||||
cache_res->cache_.emplace(
|
||||
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
|
||||
std::move(exec_context)));
|
||||
}
|
||||
|
||||
// Transfer the ownership of the engine to the engine cache, so we can
|
||||
// dump it out during conversion for TF 2.0.
|
||||
mutex_lock lock(this->engine_mutex_);
|
||||
cres->SetCalibrationTable();
|
||||
this->calibrator_ = std::move(cres->calibrator_);
|
||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||
cres->engine_->createExecutionContext());
|
||||
cache_res->cache_.emplace(
|
||||
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
|
||||
std::move(exec_context)));
|
||||
|
||||
VLOG(1) << "Calibration loop terminated " << this->name();
|
||||
}));
|
||||
VLOG(1) << "initialized calibrator resource";
|
||||
|
@ -184,13 +184,7 @@ class SerializeTRTResource : public OpKernel {
|
||||
core::ScopedUnref unref_me(resource);
|
||||
|
||||
// Terminate the calibration if any.
|
||||
if (resource->calib_ctx_) {
|
||||
// We don't save the calibration_table for TF 2.0 at the moment, it's used
|
||||
// in 1.x environment.
|
||||
string calibration_table;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, resource->calib_ctx_->SerializeToString(&calibration_table));
|
||||
}
|
||||
if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
|
||||
|
||||
// Serialize the engines and write them to file.
|
||||
std::unique_ptr<WritableFile> file;
|
||||
|
@ -30,6 +30,26 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
|
||||
string CalibrationContext::TerminateCalibration() {
|
||||
mutex_lock l(mu_);
|
||||
if (terminated_) return calibration_table_;
|
||||
|
||||
TRTInt8Calibrator* raw_calibrator = calibrator_.get();
|
||||
raw_calibrator->waitAndSetDone();
|
||||
terminated_ = true;
|
||||
|
||||
// At this point the calibration thread `thr_` is woken up and can
|
||||
// transfer the ownership of `calibrator_` and `engine_` at any time, so
|
||||
// it's not safe to use `calibrator_` below, but we can still access it
|
||||
// using raw pointer.
|
||||
// TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
|
||||
// function of this class instead.
|
||||
|
||||
thr_->join();
|
||||
calibration_table_ = raw_calibrator->getCalibrationTableAsString();
|
||||
return calibration_table_;
|
||||
}
|
||||
|
||||
const absl::string_view kTfTrtContainerName = "TF-TRT";
|
||||
|
||||
Logger& TRTEngineCacheResource::GetLogger() {
|
||||
|
@ -142,19 +142,7 @@ struct EngineContext {
|
||||
// Contains the context required to build the calibration data.
|
||||
class CalibrationContext {
|
||||
public:
|
||||
void SetCalibrationTable() {
|
||||
calibration_table_ = calibrator_->getCalibrationTableAsString();
|
||||
}
|
||||
|
||||
Status SerializeToString(string* serialized) {
|
||||
calibrator_->waitAndSetDone();
|
||||
thr_->join();
|
||||
*serialized = calibration_table_;
|
||||
if (serialized->empty()) {
|
||||
return errors::Unknown("Calibration table is empty.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
string TerminateCalibration();
|
||||
|
||||
// Lookup table for temporary staging areas of input tensors for calibration.
|
||||
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
|
||||
@ -162,12 +150,16 @@ class CalibrationContext {
|
||||
// Temporary staging areas for calibration inputs.
|
||||
std::vector<PersistentTensor> device_tensors_;
|
||||
|
||||
string calibration_table_;
|
||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||
// TODO(sami): Use threadpool threads!
|
||||
std::unique_ptr<std::thread> thr_;
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
bool terminated_ GUARDED_BY(mu_) = false;
|
||||
std::string calibration_table_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
|
||||
|
@ -100,6 +100,7 @@ class TrtPrecisionMode(object):
|
||||
]
|
||||
return precisions + [p.lower() for p in precisions]
|
||||
|
||||
|
||||
# Use a large enough number as the default max_workspace_size for TRT engines,
|
||||
# so it can produce reasonable performance results with the default.
|
||||
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
|
||||
@ -263,8 +264,7 @@ def get_tensorrt_rewriter_config(
|
||||
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
|
||||
optimizer.parameter_map[
|
||||
"use_calibration"].b = conversion_params.use_calibration
|
||||
optimizer.parameter_map[
|
||||
"max_batch_size"].i = conversion_params.max_batch_size
|
||||
optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
|
||||
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
|
||||
return rewriter_config_with_trt
|
||||
|
||||
@ -955,14 +955,20 @@ class TrtGraphConverterV2(object):
|
||||
engine_asset_dir = tempfile.mkdtemp()
|
||||
resource_map = {}
|
||||
|
||||
def _serialize_and_track_engine(canonical_engine_name):
|
||||
def _serialize_and_track_engine(node):
|
||||
"""Serialize TRT engines in the cache and track them."""
|
||||
# Don't dump the same cache twice.
|
||||
canonical_engine_name = _get_canonical_engine_name(node.name)
|
||||
if canonical_engine_name in resource_map:
|
||||
return
|
||||
|
||||
filename = os.path.join(engine_asset_dir,
|
||||
"trt-serialized-engine." + canonical_engine_name)
|
||||
if self._need_calibration:
|
||||
calibration_table = gen_trt_ops.get_calibration_data_op(
|
||||
canonical_engine_name)
|
||||
node.attr["calibration_data"].s = calibration_table.numpy()
|
||||
|
||||
try:
|
||||
gen_trt_ops.serialize_trt_resource(
|
||||
resource_name=canonical_engine_name,
|
||||
@ -980,19 +986,28 @@ class TrtGraphConverterV2(object):
|
||||
|
||||
for node in self._converted_graph_def.node:
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
_serialize_and_track_engine(_get_canonical_engine_name(node.name))
|
||||
_serialize_and_track_engine(node)
|
||||
for func in self._converted_graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == _TRT_ENGINE_OP_NAME:
|
||||
_serialize_and_track_engine(canonical_engine_name(node))
|
||||
_serialize_and_track_engine(node)
|
||||
|
||||
self._saved_model.trt_engine_resources = resource_map
|
||||
|
||||
# Rebuild the function since calibration may change the graph.
|
||||
func_to_save = wrap_function.function_from_graph_def(
|
||||
self._converted_graph_def,
|
||||
[tensor.name for tensor in self._converted_func.inputs],
|
||||
[tensor.name for tensor in self._converted_func.outputs])
|
||||
func_to_save.graph.structured_outputs = nest.pack_sequence_as(
|
||||
self._converted_func.graph.structured_outputs,
|
||||
func_to_save.graph.structured_outputs)
|
||||
|
||||
# Rewrite the signature map using the optimized ConcreteFunction.
|
||||
signatures = {
|
||||
key: value for key, value in self._saved_model.signatures.items()
|
||||
}
|
||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
||||
signatures[self._input_saved_model_signature_key] = func_to_save
|
||||
save.save(self._saved_model, output_saved_model_dir, signatures)
|
||||
|
||||
|
||||
|
@ -322,6 +322,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
maximum_cached_engines=2,
|
||||
max_batch_size=max_batch_size if max_batch_size else 1))
|
||||
|
||||
def _CheckTrtOps(self, concrete_func, check_fn=None):
|
||||
graph_def = concrete_func.graph.as_graph_def()
|
||||
trt_op_names = []
|
||||
for node in graph_def.node:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
if check_fn:
|
||||
check_fn(node)
|
||||
for func in graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
if check_fn:
|
||||
check_fn(node)
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_DynamicConversion_v2(self):
|
||||
"""Test case for trt_convert.TrtGraphConverter()."""
|
||||
@ -341,22 +358,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||
converted_func = converter.convert()
|
||||
|
||||
def _CheckTrtOps(graph_def):
|
||||
trt_op_names = [
|
||||
node.name for node in graph_def.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
for func in graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
|
||||
self._CheckTrtOps(converted_concrete_func)
|
||||
|
||||
# Save the converted model without any TRT engine cache.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -382,6 +388,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(os.path.exists(expected_asset_file))
|
||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||
|
||||
del converter
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
# Load and verify the converted model.
|
||||
#
|
||||
# TODO(laigd): the name of the new input_signature of the
|
||||
@ -390,16 +399,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
# TODO(laigd): `root_with_trt.run` is still using the original graph without
|
||||
# trt. Consider changing that.
|
||||
# _CheckTrtOps(
|
||||
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
|
||||
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
_CheckTrtOps(converted_signature.graph.as_graph_def())
|
||||
self._CheckTrtOps(converted_signature)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
output_with_trt = output_with_trt.values()[0]
|
||||
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
|
||||
|
||||
del root_with_trt
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_StaticConversion_v2(self):
|
||||
"""Test case for trt_convert.TrtGraphConverter() using static mode."""
|
||||
@ -419,23 +430,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
||||
converted_func = converter.convert()
|
||||
|
||||
def _CheckTrtOps(graph_def):
|
||||
trt_op_names = [
|
||||
node.name for node in graph_def.node if node.op == "TRTEngineOp"
|
||||
]
|
||||
for func in graph_def.library.function:
|
||||
for node in func.node_def:
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
self.assertEqual(1, len(trt_op_names))
|
||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||
|
||||
# Verify the converted GraphDef and ConcreteFunction.
|
||||
self.assertIsInstance(converted_func, def_function.Function)
|
||||
converted_concrete_func = converted_func.get_concrete_function(
|
||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
|
||||
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
||||
|
||||
# Save the converted model with the statically-built engine inlined.
|
||||
output_saved_model_dir = self.mkdtemp()
|
||||
@ -444,10 +446,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
||||
self.assertFalse(os.path.exists(unexpected_asset_file))
|
||||
|
||||
del converter
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
# Load and verify the converted model.
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
_CheckTrtOps(converted_signature.graph.as_graph_def())
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
@ -457,6 +462,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
atol=1e-6,
|
||||
rtol=1e-6)
|
||||
|
||||
del root_with_trt
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_Int8Conversion_v2(self):
|
||||
if not is_tensorrt_enabled():
|
||||
@ -493,12 +501,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(os.path.exists(expected_asset_file))
|
||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||
|
||||
del converter
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
def _CheckFn(node):
|
||||
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
|
||||
|
||||
# Load and verify the converted model.
|
||||
root_with_trt = load.load(output_saved_model_dir)
|
||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||
self.assertEqual(1, len(output_with_trt))
|
||||
|
||||
# The output of running the converted signature is a dict due to
|
||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||
self.assertAllClose(
|
||||
@ -507,6 +521,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
||||
atol=1e-6,
|
||||
rtol=1e-6)
|
||||
|
||||
# Run with an input of different batch size. It should build a new engine
|
||||
# using calibration table.
|
||||
np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||
converted_signature(ops.convert_to_tensor(np_input))
|
||||
|
||||
del root_with_trt
|
||||
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testTrtGraphConverter_DestroyEngineCache(self):
|
||||
"""Test case for trt_convert.TrtGraphConverter()."""
|
||||
|
Loading…
Reference in New Issue
Block a user