Save calibration table after calibration, so it can support multiple engines in

int8 mode.

PiperOrigin-RevId: 263857264
This commit is contained in:
Guangda Lai 2019-08-16 15:19:13 -07:00 committed by Guangda Lai
parent ed47e70590
commit 163f9df4c7
7 changed files with 112 additions and 72 deletions

View File

@ -48,12 +48,10 @@ class GetCalibrationDataOp : public OpKernel {
&resource));
core::ScopedUnref sc(resource);
auto* calib_ctx = resource->calib_ctx_.get();
// Serialize the resource as output.
string serialized_resource;
OP_REQUIRES_OK(context, calib_ctx->SerializeToString(&serialized_resource));
resource->calib_ctx_.reset();
string serialized_resource = resource->calib_ctx_->TerminateCalibration();
OP_REQUIRES(context, !serialized_resource.empty(),
errors::Unknown("Calibration table is empty."));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,

View File

@ -837,19 +837,18 @@ Status TRTEngineOp::AllocateCalibrationResources(
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
cres->calibrator_->setDone(); // Ignore further pushes
} else {
// Transfer the ownership of the engine to the engine cache, so we can
// dump it out during conversion for TF 2.0.
mutex_lock lock(this->engine_mutex_);
this->calibrator_ = std::move(cres->calibrator_);
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
cres->engine_->createExecutionContext());
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
std::move(exec_context)));
}
// Transfer the ownership of the engine to the engine cache, so we can
// dump it out during conversion for TF 2.0.
mutex_lock lock(this->engine_mutex_);
cres->SetCalibrationTable();
this->calibrator_ = std::move(cres->calibrator_);
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
cres->engine_->createExecutionContext());
cache_res->cache_.emplace(
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
std::move(exec_context)));
VLOG(1) << "Calibration loop terminated " << this->name();
}));
VLOG(1) << "initialized calibrator resource";

View File

@ -184,13 +184,7 @@ class SerializeTRTResource : public OpKernel {
core::ScopedUnref unref_me(resource);
// Terminate the calibration if any.
if (resource->calib_ctx_) {
// We don't save the calibration_table for TF 2.0 at the moment, it's used
// in 1.x environment.
string calibration_table;
OP_REQUIRES_OK(
ctx, resource->calib_ctx_->SerializeToString(&calibration_table));
}
if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
// Serialize the engines and write them to file.
std::unique_ptr<WritableFile> file;

View File

@ -30,6 +30,26 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
string CalibrationContext::TerminateCalibration() {
mutex_lock l(mu_);
if (terminated_) return calibration_table_;
TRTInt8Calibrator* raw_calibrator = calibrator_.get();
raw_calibrator->waitAndSetDone();
terminated_ = true;
// At this point the calibration thread `thr_` is woken up and can
// transfer the ownership of `calibrator_` and `engine_` at any time, so
// it's not safe to use `calibrator_` below, but we can still access it
// using raw pointer.
// TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
// function of this class instead.
thr_->join();
calibration_table_ = raw_calibrator->getCalibrationTableAsString();
return calibration_table_;
}
const absl::string_view kTfTrtContainerName = "TF-TRT";
Logger& TRTEngineCacheResource::GetLogger() {

View File

@ -142,19 +142,7 @@ struct EngineContext {
// Contains the context required to build the calibration data.
class CalibrationContext {
public:
void SetCalibrationTable() {
calibration_table_ = calibrator_->getCalibrationTableAsString();
}
Status SerializeToString(string* serialized) {
calibrator_->waitAndSetDone();
thr_->join();
*serialized = calibration_table_;
if (serialized->empty()) {
return errors::Unknown("Calibration table is empty.");
}
return Status::OK();
}
string TerminateCalibration();
// Lookup table for temporary staging areas of input tensors for calibration.
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
@ -162,12 +150,16 @@ class CalibrationContext {
// Temporary staging areas for calibration inputs.
std::vector<PersistentTensor> device_tensors_;
string calibration_table_;
std::unique_ptr<TRTInt8Calibrator> calibrator_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
// TODO(sami): Use threadpool threads!
std::unique_ptr<std::thread> thr_;
private:
mutex mu_;
bool terminated_ GUARDED_BY(mu_) = false;
std::string calibration_table_ GUARDED_BY(mu_);
};
ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;

View File

@ -100,6 +100,7 @@ class TrtPrecisionMode(object):
]
return precisions + [p.lower() for p in precisions]
# Use a large enough number as the default max_workspace_size for TRT engines,
# so it can produce reasonable performance results with the default.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
@ -263,8 +264,7 @@ def get_tensorrt_rewriter_config(
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
optimizer.parameter_map[
"use_calibration"].b = conversion_params.use_calibration
optimizer.parameter_map[
"max_batch_size"].i = conversion_params.max_batch_size
optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
return rewriter_config_with_trt
@ -955,14 +955,20 @@ class TrtGraphConverterV2(object):
engine_asset_dir = tempfile.mkdtemp()
resource_map = {}
def _serialize_and_track_engine(canonical_engine_name):
def _serialize_and_track_engine(node):
"""Serialize TRT engines in the cache and track them."""
# Don't dump the same cache twice.
canonical_engine_name = _get_canonical_engine_name(node.name)
if canonical_engine_name in resource_map:
return
filename = os.path.join(engine_asset_dir,
"trt-serialized-engine." + canonical_engine_name)
if self._need_calibration:
calibration_table = gen_trt_ops.get_calibration_data_op(
canonical_engine_name)
node.attr["calibration_data"].s = calibration_table.numpy()
try:
gen_trt_ops.serialize_trt_resource(
resource_name=canonical_engine_name,
@ -980,19 +986,28 @@ class TrtGraphConverterV2(object):
for node in self._converted_graph_def.node:
if node.op == _TRT_ENGINE_OP_NAME:
_serialize_and_track_engine(_get_canonical_engine_name(node.name))
_serialize_and_track_engine(node)
for func in self._converted_graph_def.library.function:
for node in func.node_def:
if node.op == _TRT_ENGINE_OP_NAME:
_serialize_and_track_engine(canonical_engine_name(node))
_serialize_and_track_engine(node)
self._saved_model.trt_engine_resources = resource_map
# Rebuild the function since calibration may change the graph.
func_to_save = wrap_function.function_from_graph_def(
self._converted_graph_def,
[tensor.name for tensor in self._converted_func.inputs],
[tensor.name for tensor in self._converted_func.outputs])
func_to_save.graph.structured_outputs = nest.pack_sequence_as(
self._converted_func.graph.structured_outputs,
func_to_save.graph.structured_outputs)
# Rewrite the signature map using the optimized ConcreteFunction.
signatures = {
key: value for key, value in self._saved_model.signatures.items()
}
signatures[self._input_saved_model_signature_key] = self._converted_func
signatures[self._input_saved_model_signature_key] = func_to_save
save.save(self._saved_model, output_saved_model_dir, signatures)

View File

@ -322,6 +322,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
maximum_cached_engines=2,
max_batch_size=max_batch_size if max_batch_size else 1))
def _CheckTrtOps(self, concrete_func, check_fn=None):
graph_def = concrete_func.graph.as_graph_def()
trt_op_names = []
for node in graph_def.node:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
if check_fn:
check_fn(node)
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
if check_fn:
check_fn(node)
self.assertEqual(1, len(trt_op_names))
self.assertIn("TRTEngineOp_0", trt_op_names[0])
@test_util.run_v2_only
def testTrtGraphConverter_DynamicConversion_v2(self):
"""Test case for trt_convert.TrtGraphConverter()."""
@ -341,22 +358,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
converter = self._CreateConverterV2(input_saved_model_dir)
converted_func = converter.convert()
def _CheckTrtOps(graph_def):
trt_op_names = [
node.name for node in graph_def.node if node.op == "TRTEngineOp"
]
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
self.assertEqual(1, len(trt_op_names))
self.assertIn("TRTEngineOp_0", trt_op_names[0])
# Verify the converted GraphDef and ConcreteFunction.
self.assertIsInstance(converted_func, def_function.Function)
converted_concrete_func = converted_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
self._CheckTrtOps(converted_concrete_func)
# Save the converted model without any TRT engine cache.
output_saved_model_dir = self.mkdtemp()
@ -382,6 +388,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
del converter
gc.collect() # Force GC to destroy the TRT engine cache.
# Load and verify the converted model.
#
# TODO(laigd): the name of the new input_signature of the
@ -390,16 +399,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
root_with_trt = load.load(output_saved_model_dir)
# TODO(laigd): `root_with_trt.run` is still using the original graph without
# trt. Consider changing that.
# _CheckTrtOps(
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
_CheckTrtOps(converted_signature.graph.as_graph_def())
self._CheckTrtOps(converted_signature)
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
# The output of running the converted signature is a dict due to
# compatibility reasons with V1 SavedModel signature mechanism.
output_with_trt = output_with_trt.values()[0]
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
del root_with_trt
gc.collect() # Force GC to destroy the TRT engine cache.
@test_util.run_v2_only
def testTrtGraphConverter_StaticConversion_v2(self):
"""Test case for trt_convert.TrtGraphConverter() using static mode."""
@ -419,23 +430,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
converted_func = converter.convert()
def _CheckTrtOps(graph_def):
trt_op_names = [
node.name for node in graph_def.node if node.op == "TRTEngineOp"
]
for func in graph_def.library.function:
for node in func.node_def:
if node.op == "TRTEngineOp":
trt_op_names.append(node.name)
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
self.assertEqual(1, len(trt_op_names))
self.assertIn("TRTEngineOp_0", trt_op_names[0])
def _CheckFn(node):
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
# Verify the converted GraphDef and ConcreteFunction.
self.assertIsInstance(converted_func, def_function.Function)
converted_concrete_func = converted_func.get_concrete_function(
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
self._CheckTrtOps(converted_concrete_func, _CheckFn)
# Save the converted model with the statically-built engine inlined.
output_saved_model_dir = self.mkdtemp()
@ -444,10 +446,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
self.assertFalse(os.path.exists(unexpected_asset_file))
del converter
gc.collect() # Force GC to destroy the TRT engine cache.
# Load and verify the converted model.
root_with_trt = load.load(output_saved_model_dir)
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
_CheckTrtOps(converted_signature.graph.as_graph_def())
self._CheckTrtOps(converted_signature, _CheckFn)
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
# The output of running the converted signature is a dict due to
# compatibility reasons with V1 SavedModel signature mechanism.
@ -457,6 +462,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
atol=1e-6,
rtol=1e-6)
del root_with_trt
gc.collect() # Force GC to destroy the TRT engine cache.
@test_util.run_v2_only
def testTrtGraphConverter_Int8Conversion_v2(self):
if not is_tensorrt_enabled():
@ -493,12 +501,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
self.assertTrue(os.path.exists(expected_asset_file))
self.assertTrue(os.path.getsize(expected_asset_file))
del converter
gc.collect() # Force GC to destroy the TRT engine cache.
def _CheckFn(node):
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
# Load and verify the converted model.
root_with_trt = load.load(output_saved_model_dir)
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
self._CheckTrtOps(converted_signature, _CheckFn)
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
self.assertEqual(1, len(output_with_trt))
# The output of running the converted signature is a dict due to
# compatibility reasons with V1 SavedModel signature mechanism.
self.assertAllClose(
@ -507,6 +521,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
atol=1e-6,
rtol=1e-6)
# Run with an input of different batch size. It should build a new engine
# using calibration table.
np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
converted_signature(ops.convert_to_tensor(np_input))
del root_with_trt
gc.collect() # Force GC to destroy the TRT engine cache.
@test_util.run_v2_only
def testTrtGraphConverter_DestroyEngineCache(self):
"""Test case for trt_convert.TrtGraphConverter()."""