[TF:TRT] Pass the new parameters from TrtGraphConverterV2 to the TF-TRT bridge.
Previously, we had a private field in TrtGraphConverterV2 to support the testing of dynamic shape mode. Replace this private field with the new public field use_dynamic_shape. Pass the profile_strategy field from the Python API to the TF-TRT grappler pass. The TF-TRT bridge records this information in the TRTEngineOp. Modify the TF-TRT Python tests to use the new parameters. Also verify the profile_strategy information for dynamic shape tests. PiperOrigin-RevId: 360295953 Change-Id: Ifbb0fb79e829e98db6e28810881f9d8e686617f1
This commit is contained in:
parent
23466d6a7b
commit
3ea3a13eb0
@ -580,6 +580,7 @@ cc_library(
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
|
@ -481,22 +481,27 @@ Status CreateTRTNode(const ConversionParams& params,
|
||||
NodeDef trt_node;
|
||||
NameAttrList function;
|
||||
function.set_name(StrCat(info.engine_name, "_native_segment"));
|
||||
Status status =
|
||||
node_builder.Attr("input_shapes", input_shape_protos)
|
||||
.Attr("static_engine",
|
||||
info.engine_type == EngineInfo::EngineType::TRTStatic)
|
||||
.Attr("segment_func", function)
|
||||
.Attr("serialized_segment", segment_string)
|
||||
.Attr("calibration_data", "")
|
||||
.Attr("max_cached_engines_count", info.maximum_cached_engines)
|
||||
.Attr("workspace_size_bytes", info.max_workspace_size_bytes)
|
||||
.Attr("max_batch_size", max_batch_size)
|
||||
.Attr("precision_mode", prec_string)
|
||||
.Attr("use_calibration", info.use_calibration)
|
||||
.Attr("_use_implicit_batch", params.use_implicit_batch)
|
||||
.Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
|
||||
.Attr("OutT", out_types)
|
||||
.Finalize(&trt_node);
|
||||
node_builder.Attr("input_shapes", input_shape_protos)
|
||||
.Attr("static_engine",
|
||||
info.engine_type == EngineInfo::EngineType::TRTStatic)
|
||||
.Attr("segment_func", function)
|
||||
.Attr("serialized_segment", segment_string)
|
||||
.Attr("calibration_data", "")
|
||||
.Attr("max_cached_engines_count", info.maximum_cached_engines)
|
||||
.Attr("workspace_size_bytes", info.max_workspace_size_bytes)
|
||||
.Attr("max_batch_size", max_batch_size)
|
||||
.Attr("precision_mode", prec_string)
|
||||
.Attr("use_calibration", info.use_calibration)
|
||||
.Attr("_use_implicit_batch", params.use_implicit_batch)
|
||||
.Attr("_allow_build_at_runtime", info.allow_build_at_runtime)
|
||||
.Attr("OutT", out_types);
|
||||
|
||||
if (!params.use_implicit_batch) {
|
||||
node_builder.Attr("profile_strategy",
|
||||
ProfileStrategyToName(params.profile_strategy));
|
||||
}
|
||||
|
||||
Status status = node_builder.Finalize(&trt_node);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Node construction failed with" << status;
|
||||
return status;
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
@ -46,6 +47,7 @@ struct ConversionParams {
|
||||
int max_cached_engines = 1;
|
||||
bool use_calibration = true;
|
||||
bool use_implicit_batch = true;
|
||||
ProfileStrategy profile_strategy = ProfileStrategy::kRange;
|
||||
bool allow_build_at_runtime = true;
|
||||
};
|
||||
|
||||
|
@ -76,6 +76,10 @@ Status TRTOptimizationPass::Init(
|
||||
if (params.count("use_implicit_batch")) {
|
||||
use_implicit_batch_ = params.at("use_implicit_batch").b();
|
||||
}
|
||||
if (params.count("profile_strategy")) {
|
||||
TF_RETURN_IF_ERROR(ProfileStrategyFromName(
|
||||
params.at("profile_strategy").s(), &profile_strategy_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -242,6 +246,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster,
|
||||
cp.max_cached_engines = max_cached_batches_;
|
||||
cp.use_calibration = use_calibration_;
|
||||
cp.use_implicit_batch = use_implicit_batch_;
|
||||
cp.profile_strategy = profile_strategy_;
|
||||
cp.allow_build_at_runtime = allow_build_at_runtime_;
|
||||
auto status = ConvertAfterShapes(cp);
|
||||
VLOG(1) << "Returning from " << name_;
|
||||
|
@ -42,6 +42,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
max_workspace_size_bytes_(256LL << 20),
|
||||
use_calibration_(true),
|
||||
use_implicit_batch_(true),
|
||||
profile_strategy_(ProfileStrategy::kRange),
|
||||
allow_build_at_runtime_(true) {
|
||||
VLOG(1) << "Constructing " << name_;
|
||||
}
|
||||
@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer {
|
||||
int64_t max_workspace_size_bytes_;
|
||||
bool use_calibration_;
|
||||
bool use_implicit_batch_;
|
||||
ProfileStrategy profile_strategy_;
|
||||
bool allow_build_at_runtime_;
|
||||
};
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
@ -257,6 +258,35 @@ int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
|
||||
return n_input / n_profiles;
|
||||
}
|
||||
|
||||
// Maps a ProfileStrategy value to its display name.
//
// Returns "Range", "Optimal", "Range+Optimal" or
// "ImplicitBatchModeCompatible" for the corresponding enumerator, and
// "Unknown" for any value that is not one of the listed enumerators.
string ProfileStrategyToName(const ProfileStrategy strategy) {
  // Start from the fallback label; each recognized enumerator overrides it.
  const char* label = "Unknown";
  switch (strategy) {
    case ProfileStrategy::kRange:
      label = "Range";
      break;
    case ProfileStrategy::kOptimal:
      label = "Optimal";
      break;
    case ProfileStrategy::kRangeOptimal:
      label = "Range+Optimal";
      break;
    case ProfileStrategy::kImplicitBatchModeCompatible:
      label = "ImplicitBatchModeCompatible";
      break;
  }
  return label;
}
|
||||
|
||||
// Parses a profile strategy name into the matching ProfileStrategy value.
//
// Matching is ASCII case-insensitive, so both the lower-case strings passed
// down by the Python API (e.g. "range") and the names produced by
// ProfileStrategyToName (e.g. "Range") are accepted.
//
// Args:
//   name: one of "range", "optimal", "range+optimal" or
//     "implicitbatchmodecompatible", in any letter case.
//   strategy: output; written only when `name` is recognized.
//
// Returns Status::OK() on success, or an InvalidArgument error quoting the
// original `name` when it matches no known strategy.
Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) {
  // Fold only ASCII upper-case letters so the comparison does not depend on
  // the process locale.
  string name_lowercase(name);
  for (char& c : name_lowercase) {
    if (c >= 'A' && c <= 'Z') {
      c = static_cast<char>(c - 'A' + 'a');
    }
  }

  if (name_lowercase == "range") {
    *strategy = ProfileStrategy::kRange;
  } else if (name_lowercase == "optimal") {
    *strategy = ProfileStrategy::kOptimal;
  } else if (name_lowercase == "range+optimal") {
    *strategy = ProfileStrategy::kRangeOptimal;
  } else if (name_lowercase == "implicitbatchmodecompatible") {
    *strategy = ProfileStrategy::kImplicitBatchModeCompatible;
  } else {
    // Report the name exactly as the caller supplied it.
    return errors::InvalidArgument("Invalid profile strategy: ", name);
  }
  return Status::OK();
}
|
||||
|
||||
#endif
|
||||
|
||||
absl::string_view GetDeviceName(const Node* node) {
|
||||
|
@ -200,6 +200,17 @@ absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
|
||||
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
|
||||
const DeviceNameUtils::ParsedName& a, absl::string_view b);
|
||||
|
||||
// Optimization profile generation strategies.
|
||||
enum class ProfileStrategy {
|
||||
// Named "Range" by ProfileStrategyToName. Default value of
// ConversionParams::profile_strategy and of TRTOptimizationPass.
// TrtShapeOptimizationProfile::InitProfiles currently treats it like kOptimal.
kRange,
|
||||
// Named "Optimal". InitProfiles builds profiles with OptimalStrategy().
kOptimal,
|
||||
// Named "Range+Optimal". InitProfiles currently treats it like kOptimal.
kRangeOptimal,
|
||||
// Named "ImplicitBatchModeCompatible". InitProfiles builds profiles with
// ImplicitBatchModeCompatibleStrategy().
kImplicitBatchModeCompatible,
|
||||
};
|
||||
|
||||
string ProfileStrategyToName(const ProfileStrategy strategy);
|
||||
Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy);
|
||||
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -56,7 +56,8 @@ REGISTER_OP("TRTEngineOp")
|
||||
.Attr("cached_engine_batches: list(int) >= 0 = []")
|
||||
.Attr("fixed_input_size: bool = true")
|
||||
.Attr("output_shapes: list(shape) = []")
|
||||
.Attr("static_engine: bool = true");
|
||||
.Attr("static_engine: bool = true")
|
||||
.Attr("profile_strategy: string = ''");
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
|
@ -109,6 +109,10 @@ void TrtShapeOptimizationProfile::InitProfiles(
|
||||
VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy";
|
||||
ImplicitBatchModeCompatibleStrategy();
|
||||
break;
|
||||
// Treat all other strategies the same as kOptimal for now. Implementing
|
||||
// those is outlined in the dynamic shape support implementation plan.
|
||||
case ProfileStrategy::kRange:
|
||||
case ProfileStrategy::kRangeOptimal:
|
||||
case ProfileStrategy::kOptimal:
|
||||
VLOG(1) << "Creating profiles with Optimal strategy";
|
||||
OptimalStrategy();
|
||||
|
@ -115,12 +115,6 @@ struct OptimizationProfileConfig {
|
||||
}
|
||||
};
|
||||
|
||||
// Optimization profile generation strategies.
|
||||
enum class ProfileStrategy {
|
||||
kImplicitBatchModeCompatible,
|
||||
kOptimal,
|
||||
};
|
||||
|
||||
// Manages Optimization profiles during TRT Engine construction.
|
||||
//
|
||||
// An optimization profile describes a range of dimensions for each TRT network
|
||||
|
@ -163,6 +163,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
self._trt_test_params = None
|
||||
self._disable_non_trt_optimizers = False
|
||||
self._use_implicit_batch = True
|
||||
self._profile_strategy = "Unknown"
|
||||
|
||||
def setUp(self):
|
||||
"""Setup method."""
|
||||
@ -264,8 +265,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
def DisableNonTrtOptimizers(self):
|
||||
self._disable_non_trt_optimizers = True
|
||||
|
||||
def DisableImplicitBatchMode(self):
|
||||
def SetDynamicShapeModeAndProfileStrategy(self, profile_strategy="Range"):
|
||||
self._use_implicit_batch = False
|
||||
self._profile_strategy = profile_strategy
|
||||
|
||||
def GetParams(self):
|
||||
"""Returns a TfTrtIntegrationTestParams for the test."""
|
||||
@ -453,11 +455,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
if run_params.is_v2:
|
||||
converter_v2 = trt_convert.TrtGraphConverterV2(
|
||||
input_saved_model_dir=saved_model_dir,
|
||||
conversion_params=conversion_params)
|
||||
conversion_params=conversion_params,
|
||||
use_dynamic_shape=not self._use_implicit_batch,
|
||||
dynamic_shape_profile_strategy=self._profile_strategy)
|
||||
if self._disable_non_trt_optimizers:
|
||||
converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access
|
||||
if not self._use_implicit_batch:
|
||||
converter_v2._test_only_use_implicit_batch = False # pylint: disable=protected-access
|
||||
return converter_v2
|
||||
|
||||
converter_v1 = trt_convert.TrtGraphConverter(
|
||||
@ -873,6 +875,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
|
||||
all_op_names.append(node.name)
|
||||
if node.op == "TRTEngineOp":
|
||||
trt_op_names.append(node.name)
|
||||
if not self._use_implicit_batch:
|
||||
self.assertEqual(
|
||||
self._ToString(node.attr["profile_strategy"].s).lower(),
|
||||
self._profile_strategy.lower())
|
||||
|
||||
all_op_names = self._Canonicalize(all_op_names)
|
||||
trt_op_names = self._RemoveGraphSequenceNumber(
|
||||
|
@ -125,8 +125,8 @@ class ExplicitBatchTest(TrtModeTestBase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Disable implicit batch mode for testing explicit batch mode.
|
||||
self.DisableImplicitBatchMode()
|
||||
self.SetDynamicShapeModeAndProfileStrategy(
|
||||
profile_strategy="ImplicitBatchModeCompatible")
|
||||
|
||||
|
||||
class DynamicShapesTest(TrtModeTestBase):
|
||||
@ -162,7 +162,9 @@ class DynamicShapesTest(TrtModeTestBase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.DisableImplicitBatchMode()
|
||||
self.SetDynamicShapeModeAndProfileStrategy(
|
||||
profile_strategy="ImplicitBatchModeCompatible")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -316,8 +316,11 @@ def _get_tensorrt_rewriter_config(conversion_params,
|
||||
if max_batch_size is not None:
|
||||
optimizer.parameter_map["max_batch_size"].i = max_batch_size
|
||||
optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
|
||||
# While we accept case insensitive strings from the users, we only pass the
|
||||
# strings in lower case to the TF-TRT converter.
|
||||
if not use_implicit_batch:
|
||||
optimizer.parameter_map["profile_strategy"].s = _to_bytes(profile_strategy)
|
||||
optimizer.parameter_map["profile_strategy"].s = _to_bytes(
|
||||
profile_strategy.lower())
|
||||
|
||||
# Disabling optimizers should happen after defining the TF-TRT grappler pass
|
||||
# otherwise the template can overwrite the disablement.
|
||||
@ -1013,6 +1016,7 @@ class TrtGraphConverterV2(object):
|
||||
else:
|
||||
self._use_dynamic_shape = use_dynamic_shape
|
||||
|
||||
self._profile_strategy = "Unknown"
|
||||
if self._use_dynamic_shape:
|
||||
if dynamic_shape_profile_strategy is None:
|
||||
self._profile_strategy = PROFILE_STRATEGY_RANGE
|
||||
@ -1022,10 +1026,9 @@ class TrtGraphConverterV2(object):
|
||||
|
||||
# Fields to support TF-TRT testing and shouldn't be used for other purpose.
|
||||
self._test_only_disable_non_trt_optimizers = False
|
||||
self._test_only_use_implicit_batch = True
|
||||
|
||||
def _need_trt_profiles(self):
|
||||
return not self._test_only_use_implicit_batch
|
||||
return self._use_dynamic_shape
|
||||
|
||||
def _run_conversion(self, meta_graph_def):
|
||||
"""Run Grappler's OptimizeGraph() tool to convert the graph.
|
||||
@ -1042,7 +1045,8 @@ class TrtGraphConverterV2(object):
|
||||
is_dynamic_op=True,
|
||||
max_batch_size=None,
|
||||
disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
|
||||
use_implicit_batch=self._test_only_use_implicit_batch)
|
||||
use_implicit_batch=not self._use_dynamic_shape,
|
||||
profile_strategy=self._profile_strategy)
|
||||
grappler_session_config.graph_options.rewrite_options.CopyFrom(
|
||||
custom_rewriter_config)
|
||||
return tf_optimizer.OptimizeGraph(
|
||||
|
Loading…
x
Reference in New Issue
Block a user