Save calibration table after calibration, so it can support multiple engines in
int8 mode. PiperOrigin-RevId: 263857264
This commit is contained in:
parent
ed47e70590
commit
163f9df4c7
@ -48,12 +48,10 @@ class GetCalibrationDataOp : public OpKernel {
|
|||||||
&resource));
|
&resource));
|
||||||
core::ScopedUnref sc(resource);
|
core::ScopedUnref sc(resource);
|
||||||
|
|
||||||
auto* calib_ctx = resource->calib_ctx_.get();
|
|
||||||
|
|
||||||
// Serialize the resource as output.
|
// Serialize the resource as output.
|
||||||
string serialized_resource;
|
string serialized_resource = resource->calib_ctx_->TerminateCalibration();
|
||||||
OP_REQUIRES_OK(context, calib_ctx->SerializeToString(&serialized_resource));
|
OP_REQUIRES(context, !serialized_resource.empty(),
|
||||||
resource->calib_ctx_.reset();
|
errors::Unknown("Calibration table is empty."));
|
||||||
|
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
|
@ -837,18 +837,17 @@ Status TRTEngineOp::AllocateCalibrationResources(
|
|||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
LOG(ERROR) << "Calibration failed: " << s;
|
LOG(ERROR) << "Calibration failed: " << s;
|
||||||
cres->calibrator_->setDone(); // Ignore further pushes
|
cres->calibrator_->setDone(); // Ignore further pushes
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// Transfer the ownership of the engine to the engine cache, so we can
|
// Transfer the ownership of the engine to the engine cache, so we can
|
||||||
// dump it out during conversion for TF 2.0.
|
// dump it out during conversion for TF 2.0.
|
||||||
mutex_lock lock(this->engine_mutex_);
|
mutex_lock lock(this->engine_mutex_);
|
||||||
cres->SetCalibrationTable();
|
|
||||||
this->calibrator_ = std::move(cres->calibrator_);
|
this->calibrator_ = std::move(cres->calibrator_);
|
||||||
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
|
||||||
cres->engine_->createExecutionContext());
|
cres->engine_->createExecutionContext());
|
||||||
cache_res->cache_.emplace(
|
cache_res->cache_.emplace(
|
||||||
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
|
shapes, absl::make_unique<EngineContext>(std::move(cres->engine_),
|
||||||
std::move(exec_context)));
|
std::move(exec_context)));
|
||||||
|
}
|
||||||
|
|
||||||
VLOG(1) << "Calibration loop terminated " << this->name();
|
VLOG(1) << "Calibration loop terminated " << this->name();
|
||||||
}));
|
}));
|
||||||
|
@ -184,13 +184,7 @@ class SerializeTRTResource : public OpKernel {
|
|||||||
core::ScopedUnref unref_me(resource);
|
core::ScopedUnref unref_me(resource);
|
||||||
|
|
||||||
// Terminate the calibration if any.
|
// Terminate the calibration if any.
|
||||||
if (resource->calib_ctx_) {
|
if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration();
|
||||||
// We don't save the calibration_table for TF 2.0 at the moment, it's used
|
|
||||||
// in 1.x environment.
|
|
||||||
string calibration_table;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, resource->calib_ctx_->SerializeToString(&calibration_table));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Serialize the engines and write them to file.
|
// Serialize the engines and write them to file.
|
||||||
std::unique_ptr<WritableFile> file;
|
std::unique_ptr<WritableFile> file;
|
||||||
|
@ -30,6 +30,26 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
|
|
||||||
|
string CalibrationContext::TerminateCalibration() {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
if (terminated_) return calibration_table_;
|
||||||
|
|
||||||
|
TRTInt8Calibrator* raw_calibrator = calibrator_.get();
|
||||||
|
raw_calibrator->waitAndSetDone();
|
||||||
|
terminated_ = true;
|
||||||
|
|
||||||
|
// At this point the calibration thread `thr_` is woken up and can
|
||||||
|
// transfer the ownership of `calibrator_` and `engine_` at any time, so
|
||||||
|
// it's not safe to use `calibrator_` below, but we can still access it
|
||||||
|
// using raw pointer.
|
||||||
|
// TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member
|
||||||
|
// function of this class instead.
|
||||||
|
|
||||||
|
thr_->join();
|
||||||
|
calibration_table_ = raw_calibrator->getCalibrationTableAsString();
|
||||||
|
return calibration_table_;
|
||||||
|
}
|
||||||
|
|
||||||
const absl::string_view kTfTrtContainerName = "TF-TRT";
|
const absl::string_view kTfTrtContainerName = "TF-TRT";
|
||||||
|
|
||||||
Logger& TRTEngineCacheResource::GetLogger() {
|
Logger& TRTEngineCacheResource::GetLogger() {
|
||||||
|
@ -142,19 +142,7 @@ struct EngineContext {
|
|||||||
// Contains the context required to build the calibration data.
|
// Contains the context required to build the calibration data.
|
||||||
class CalibrationContext {
|
class CalibrationContext {
|
||||||
public:
|
public:
|
||||||
void SetCalibrationTable() {
|
string TerminateCalibration();
|
||||||
calibration_table_ = calibrator_->getCalibrationTableAsString();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status SerializeToString(string* serialized) {
|
|
||||||
calibrator_->waitAndSetDone();
|
|
||||||
thr_->join();
|
|
||||||
*serialized = calibration_table_;
|
|
||||||
if (serialized->empty()) {
|
|
||||||
return errors::Unknown("Calibration table is empty.");
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lookup table for temporary staging areas of input tensors for calibration.
|
// Lookup table for temporary staging areas of input tensors for calibration.
|
||||||
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
|
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
|
||||||
@ -162,12 +150,16 @@ class CalibrationContext {
|
|||||||
// Temporary staging areas for calibration inputs.
|
// Temporary staging areas for calibration inputs.
|
||||||
std::vector<PersistentTensor> device_tensors_;
|
std::vector<PersistentTensor> device_tensors_;
|
||||||
|
|
||||||
string calibration_table_;
|
|
||||||
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
std::unique_ptr<TRTInt8Calibrator> calibrator_;
|
||||||
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
|
||||||
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
|
||||||
// TODO(sami): Use threadpool threads!
|
// TODO(sami): Use threadpool threads!
|
||||||
std::unique_ptr<std::thread> thr_;
|
std::unique_ptr<std::thread> thr_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutex mu_;
|
||||||
|
bool terminated_ GUARDED_BY(mu_) = false;
|
||||||
|
std::string calibration_table_ GUARDED_BY(mu_);
|
||||||
};
|
};
|
||||||
|
|
||||||
ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
|
ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;
|
||||||
|
@ -100,6 +100,7 @@ class TrtPrecisionMode(object):
|
|||||||
]
|
]
|
||||||
return precisions + [p.lower() for p in precisions]
|
return precisions + [p.lower() for p in precisions]
|
||||||
|
|
||||||
|
|
||||||
# Use a large enough number as the default max_workspace_size for TRT engines,
|
# Use a large enough number as the default max_workspace_size for TRT engines,
|
||||||
# so it can produce reasonable performance results with the default.
|
# so it can produce reasonable performance results with the default.
|
||||||
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
|
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
|
||||||
@ -263,8 +264,7 @@ def get_tensorrt_rewriter_config(
|
|||||||
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
|
"maximum_cached_engines"].i = conversion_params.maximum_cached_engines
|
||||||
optimizer.parameter_map[
|
optimizer.parameter_map[
|
||||||
"use_calibration"].b = conversion_params.use_calibration
|
"use_calibration"].b = conversion_params.use_calibration
|
||||||
optimizer.parameter_map[
|
optimizer.parameter_map["max_batch_size"].i = conversion_params.max_batch_size
|
||||||
"max_batch_size"].i = conversion_params.max_batch_size
|
|
||||||
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
|
optimizer.parameter_map["is_dynamic_op"].b = conversion_params.is_dynamic_op
|
||||||
return rewriter_config_with_trt
|
return rewriter_config_with_trt
|
||||||
|
|
||||||
@ -955,14 +955,20 @@ class TrtGraphConverterV2(object):
|
|||||||
engine_asset_dir = tempfile.mkdtemp()
|
engine_asset_dir = tempfile.mkdtemp()
|
||||||
resource_map = {}
|
resource_map = {}
|
||||||
|
|
||||||
def _serialize_and_track_engine(canonical_engine_name):
|
def _serialize_and_track_engine(node):
|
||||||
"""Serialize TRT engines in the cache and track them."""
|
"""Serialize TRT engines in the cache and track them."""
|
||||||
# Don't dump the same cache twice.
|
# Don't dump the same cache twice.
|
||||||
|
canonical_engine_name = _get_canonical_engine_name(node.name)
|
||||||
if canonical_engine_name in resource_map:
|
if canonical_engine_name in resource_map:
|
||||||
return
|
return
|
||||||
|
|
||||||
filename = os.path.join(engine_asset_dir,
|
filename = os.path.join(engine_asset_dir,
|
||||||
"trt-serialized-engine." + canonical_engine_name)
|
"trt-serialized-engine." + canonical_engine_name)
|
||||||
|
if self._need_calibration:
|
||||||
|
calibration_table = gen_trt_ops.get_calibration_data_op(
|
||||||
|
canonical_engine_name)
|
||||||
|
node.attr["calibration_data"].s = calibration_table.numpy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gen_trt_ops.serialize_trt_resource(
|
gen_trt_ops.serialize_trt_resource(
|
||||||
resource_name=canonical_engine_name,
|
resource_name=canonical_engine_name,
|
||||||
@ -980,19 +986,28 @@ class TrtGraphConverterV2(object):
|
|||||||
|
|
||||||
for node in self._converted_graph_def.node:
|
for node in self._converted_graph_def.node:
|
||||||
if node.op == _TRT_ENGINE_OP_NAME:
|
if node.op == _TRT_ENGINE_OP_NAME:
|
||||||
_serialize_and_track_engine(_get_canonical_engine_name(node.name))
|
_serialize_and_track_engine(node)
|
||||||
for func in self._converted_graph_def.library.function:
|
for func in self._converted_graph_def.library.function:
|
||||||
for node in func.node_def:
|
for node in func.node_def:
|
||||||
if node.op == _TRT_ENGINE_OP_NAME:
|
if node.op == _TRT_ENGINE_OP_NAME:
|
||||||
_serialize_and_track_engine(canonical_engine_name(node))
|
_serialize_and_track_engine(node)
|
||||||
|
|
||||||
self._saved_model.trt_engine_resources = resource_map
|
self._saved_model.trt_engine_resources = resource_map
|
||||||
|
|
||||||
|
# Rebuild the function since calibration may change the graph.
|
||||||
|
func_to_save = wrap_function.function_from_graph_def(
|
||||||
|
self._converted_graph_def,
|
||||||
|
[tensor.name for tensor in self._converted_func.inputs],
|
||||||
|
[tensor.name for tensor in self._converted_func.outputs])
|
||||||
|
func_to_save.graph.structured_outputs = nest.pack_sequence_as(
|
||||||
|
self._converted_func.graph.structured_outputs,
|
||||||
|
func_to_save.graph.structured_outputs)
|
||||||
|
|
||||||
# Rewrite the signature map using the optimized ConcreteFunction.
|
# Rewrite the signature map using the optimized ConcreteFunction.
|
||||||
signatures = {
|
signatures = {
|
||||||
key: value for key, value in self._saved_model.signatures.items()
|
key: value for key, value in self._saved_model.signatures.items()
|
||||||
}
|
}
|
||||||
signatures[self._input_saved_model_signature_key] = self._converted_func
|
signatures[self._input_saved_model_signature_key] = func_to_save
|
||||||
save.save(self._saved_model, output_saved_model_dir, signatures)
|
save.save(self._saved_model, output_saved_model_dir, signatures)
|
||||||
|
|
||||||
|
|
||||||
|
@ -322,6 +322,23 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
maximum_cached_engines=2,
|
maximum_cached_engines=2,
|
||||||
max_batch_size=max_batch_size if max_batch_size else 1))
|
max_batch_size=max_batch_size if max_batch_size else 1))
|
||||||
|
|
||||||
|
def _CheckTrtOps(self, concrete_func, check_fn=None):
|
||||||
|
graph_def = concrete_func.graph.as_graph_def()
|
||||||
|
trt_op_names = []
|
||||||
|
for node in graph_def.node:
|
||||||
|
if node.op == "TRTEngineOp":
|
||||||
|
trt_op_names.append(node.name)
|
||||||
|
if check_fn:
|
||||||
|
check_fn(node)
|
||||||
|
for func in graph_def.library.function:
|
||||||
|
for node in func.node_def:
|
||||||
|
if node.op == "TRTEngineOp":
|
||||||
|
trt_op_names.append(node.name)
|
||||||
|
if check_fn:
|
||||||
|
check_fn(node)
|
||||||
|
self.assertEqual(1, len(trt_op_names))
|
||||||
|
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testTrtGraphConverter_DynamicConversion_v2(self):
|
def testTrtGraphConverter_DynamicConversion_v2(self):
|
||||||
"""Test case for trt_convert.TrtGraphConverter()."""
|
"""Test case for trt_convert.TrtGraphConverter()."""
|
||||||
@ -341,22 +358,11 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
converter = self._CreateConverterV2(input_saved_model_dir)
|
converter = self._CreateConverterV2(input_saved_model_dir)
|
||||||
converted_func = converter.convert()
|
converted_func = converter.convert()
|
||||||
|
|
||||||
def _CheckTrtOps(graph_def):
|
|
||||||
trt_op_names = [
|
|
||||||
node.name for node in graph_def.node if node.op == "TRTEngineOp"
|
|
||||||
]
|
|
||||||
for func in graph_def.library.function:
|
|
||||||
for node in func.node_def:
|
|
||||||
if node.op == "TRTEngineOp":
|
|
||||||
trt_op_names.append(node.name)
|
|
||||||
self.assertEqual(1, len(trt_op_names))
|
|
||||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
# Verify the converted GraphDef and ConcreteFunction.
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
self.assertIsInstance(converted_func, def_function.Function)
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
converted_concrete_func = converted_func.get_concrete_function(
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||||
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
|
self._CheckTrtOps(converted_concrete_func)
|
||||||
|
|
||||||
# Save the converted model without any TRT engine cache.
|
# Save the converted model without any TRT engine cache.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
@ -382,6 +388,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(os.path.exists(expected_asset_file))
|
self.assertTrue(os.path.exists(expected_asset_file))
|
||||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||||
|
|
||||||
|
del converter
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
# Load and verify the converted model.
|
# Load and verify the converted model.
|
||||||
#
|
#
|
||||||
# TODO(laigd): the name of the new input_signature of the
|
# TODO(laigd): the name of the new input_signature of the
|
||||||
@ -390,16 +399,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
root_with_trt = load.load(output_saved_model_dir)
|
root_with_trt = load.load(output_saved_model_dir)
|
||||||
# TODO(laigd): `root_with_trt.run` is still using the original graph without
|
# TODO(laigd): `root_with_trt.run` is still using the original graph without
|
||||||
# trt. Consider changing that.
|
# trt. Consider changing that.
|
||||||
# _CheckTrtOps(
|
# self._CheckTrtOps(root_with_trt.run.get_concrete_function())
|
||||||
# root_with_trt.run.get_concrete_function().graph.as_graph_def())
|
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
_CheckTrtOps(converted_signature.graph.as_graph_def())
|
self._CheckTrtOps(converted_signature)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
output_with_trt = output_with_trt.values()[0]
|
output_with_trt = output_with_trt.values()[0]
|
||||||
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
|
self.assertAllClose(expected_output, output_with_trt, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
del root_with_trt
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testTrtGraphConverter_StaticConversion_v2(self):
|
def testTrtGraphConverter_StaticConversion_v2(self):
|
||||||
"""Test case for trt_convert.TrtGraphConverter() using static mode."""
|
"""Test case for trt_convert.TrtGraphConverter() using static mode."""
|
||||||
@ -419,23 +430,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
converter = self._CreateConverterV2(input_saved_model_dir, max_batch_size=4)
|
||||||
converted_func = converter.convert()
|
converted_func = converter.convert()
|
||||||
|
|
||||||
def _CheckTrtOps(graph_def):
|
def _CheckFn(node):
|
||||||
trt_op_names = [
|
|
||||||
node.name for node in graph_def.node if node.op == "TRTEngineOp"
|
|
||||||
]
|
|
||||||
for func in graph_def.library.function:
|
|
||||||
for node in func.node_def:
|
|
||||||
if node.op == "TRTEngineOp":
|
|
||||||
trt_op_names.append(node.name)
|
|
||||||
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
self.assertTrue(len(node.attr["serialized_segment"].s), node.name)
|
||||||
self.assertEqual(1, len(trt_op_names))
|
|
||||||
self.assertIn("TRTEngineOp_0", trt_op_names[0])
|
|
||||||
|
|
||||||
# Verify the converted GraphDef and ConcreteFunction.
|
# Verify the converted GraphDef and ConcreteFunction.
|
||||||
self.assertIsInstance(converted_func, def_function.Function)
|
self.assertIsInstance(converted_func, def_function.Function)
|
||||||
converted_concrete_func = converted_func.get_concrete_function(
|
converted_concrete_func = converted_func.get_concrete_function(
|
||||||
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
tensor_spec.TensorSpec(shape=[None, 1, 1], dtype=dtypes.float32))
|
||||||
_CheckTrtOps(converted_concrete_func.graph.as_graph_def())
|
self._CheckTrtOps(converted_concrete_func, _CheckFn)
|
||||||
|
|
||||||
# Save the converted model with the statically-built engine inlined.
|
# Save the converted model with the statically-built engine inlined.
|
||||||
output_saved_model_dir = self.mkdtemp()
|
output_saved_model_dir = self.mkdtemp()
|
||||||
@ -444,10 +446,13 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
output_saved_model_dir, "assets/trt-serialized-engine.TRTEngineOp_0")
|
||||||
self.assertFalse(os.path.exists(unexpected_asset_file))
|
self.assertFalse(os.path.exists(unexpected_asset_file))
|
||||||
|
|
||||||
|
del converter
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
# Load and verify the converted model.
|
# Load and verify the converted model.
|
||||||
root_with_trt = load.load(output_saved_model_dir)
|
root_with_trt = load.load(output_saved_model_dir)
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
_CheckTrtOps(converted_signature.graph.as_graph_def())
|
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
@ -457,6 +462,9 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
atol=1e-6,
|
atol=1e-6,
|
||||||
rtol=1e-6)
|
rtol=1e-6)
|
||||||
|
|
||||||
|
del root_with_trt
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testTrtGraphConverter_Int8Conversion_v2(self):
|
def testTrtGraphConverter_Int8Conversion_v2(self):
|
||||||
if not is_tensorrt_enabled():
|
if not is_tensorrt_enabled():
|
||||||
@ -493,12 +501,18 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(os.path.exists(expected_asset_file))
|
self.assertTrue(os.path.exists(expected_asset_file))
|
||||||
self.assertTrue(os.path.getsize(expected_asset_file))
|
self.assertTrue(os.path.getsize(expected_asset_file))
|
||||||
|
|
||||||
|
del converter
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
|
def _CheckFn(node):
|
||||||
|
self.assertTrue(len(node.attr["calibration_data"].s), node.name)
|
||||||
|
|
||||||
# Load and verify the converted model.
|
# Load and verify the converted model.
|
||||||
root_with_trt = load.load(output_saved_model_dir)
|
root_with_trt = load.load(output_saved_model_dir)
|
||||||
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
converted_signature = root_with_trt.signatures[_SAVED_MODEL_SIGNATURE_KEY]
|
||||||
|
self._CheckTrtOps(converted_signature, _CheckFn)
|
||||||
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
output_with_trt = converted_signature(ops.convert_to_tensor(np_input))
|
||||||
self.assertEqual(1, len(output_with_trt))
|
self.assertEqual(1, len(output_with_trt))
|
||||||
|
|
||||||
# The output of running the converted signature is a dict due to
|
# The output of running the converted signature is a dict due to
|
||||||
# compatibility reasons with V1 SavedModel signature mechanism.
|
# compatibility reasons with V1 SavedModel signature mechanism.
|
||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
@ -507,6 +521,14 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
|
|||||||
atol=1e-6,
|
atol=1e-6,
|
||||||
rtol=1e-6)
|
rtol=1e-6)
|
||||||
|
|
||||||
|
# Run with an input of different batch size. It should build a new engine
|
||||||
|
# using calibration table.
|
||||||
|
np_input = np.random.random_sample([5, 1, 1]).astype(np.float32)
|
||||||
|
converted_signature(ops.convert_to_tensor(np_input))
|
||||||
|
|
||||||
|
del root_with_trt
|
||||||
|
gc.collect() # Force GC to destroy the TRT engine cache.
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testTrtGraphConverter_DestroyEngineCache(self):
|
def testTrtGraphConverter_DestroyEngineCache(self):
|
||||||
"""Test case for trt_convert.TrtGraphConverter()."""
|
"""Test case for trt_convert.TrtGraphConverter()."""
|
||||||
|
Loading…
Reference in New Issue
Block a user