diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 64ac229d905..f68eee7e813 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -580,6 +580,7 @@ cc_library( copts = tf_copts(), deps = [ "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index e7e5cef8b86..e2fa7f30873 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -481,22 +481,27 @@ Status CreateTRTNode(const ConversionParams& params, NodeDef trt_node; NameAttrList function; function.set_name(StrCat(info.engine_name, "_native_segment")); - Status status = - node_builder.Attr("input_shapes", input_shape_protos) - .Attr("static_engine", - info.engine_type == EngineInfo::EngineType::TRTStatic) - .Attr("segment_func", function) - .Attr("serialized_segment", segment_string) - .Attr("calibration_data", "") - .Attr("max_cached_engines_count", info.maximum_cached_engines) - .Attr("workspace_size_bytes", info.max_workspace_size_bytes) - .Attr("max_batch_size", max_batch_size) - .Attr("precision_mode", prec_string) - .Attr("use_calibration", info.use_calibration) - .Attr("_use_implicit_batch", params.use_implicit_batch) - .Attr("_allow_build_at_runtime", info.allow_build_at_runtime) - .Attr("OutT", out_types) - .Finalize(&trt_node); + node_builder.Attr("input_shapes", input_shape_protos) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_func", function) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("workspace_size_bytes", 
info.max_workspace_size_bytes) + .Attr("max_batch_size", max_batch_size) + .Attr("precision_mode", prec_string) + .Attr("use_calibration", info.use_calibration) + .Attr("_use_implicit_batch", params.use_implicit_batch) + .Attr("_allow_build_at_runtime", info.allow_build_at_runtime) + .Attr("OutT", out_types); + + if (!params.use_implicit_batch) { + node_builder.Attr("profile_strategy", + ProfileStrategyToName(params.profile_strategy)); + } + + Status status = node_builder.Finalize(&trt_node); if (!status.ok()) { LOG(ERROR) << "Node construction failed with" << status; return status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index d3897e864fa..43a551e01bc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -46,6 +47,7 @@ struct ConversionParams { int max_cached_engines = 1; bool use_calibration = true; bool use_implicit_batch = true; + ProfileStrategy profile_strategy = ProfileStrategy::kRange; bool allow_build_at_runtime = true; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 12fea3ade40..324ba0cf682 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -76,6 +76,10 @@ Status TRTOptimizationPass::Init( if (params.count("use_implicit_batch")) { use_implicit_batch_ = params.at("use_implicit_batch").b(); } + if (params.count("profile_strategy")) { + 
TF_RETURN_IF_ERROR(ProfileStrategyFromName( + params.at("profile_strategy").s(), &profile_strategy_)); + } return Status::OK(); } @@ -242,6 +246,7 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; cp.use_implicit_batch = use_implicit_batch_; + cp.profile_strategy = profile_strategy_; cp.allow_build_at_runtime = allow_build_at_runtime_; auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index e0aaa5500ab..fd984e5772c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -42,6 +42,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { max_workspace_size_bytes_(256LL << 20), use_calibration_(true), use_implicit_batch_(true), + profile_strategy_(ProfileStrategy::kRange), allow_build_at_runtime_(true) { VLOG(1) << "Constructing " << name_; } @@ -75,6 +76,7 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { int64_t max_workspace_size_bytes_; bool use_calibration_; bool use_implicit_batch_; + ProfileStrategy profile_strategy_; bool allow_build_at_runtime_; }; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a4bd8d5afd9..34cbaf9a15e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "absl/strings/ascii.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" @@ -257,6 +258,42 @@ int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { return n_input / n_profiles; } +// Returns the display name of `strategy`, or "Unknown" for a value that is not +// one of the ProfileStrategy enumerators. +string ProfileStrategyToName(const ProfileStrategy strategy) { + switch (strategy) { + case ProfileStrategy::kRange: + return "Range"; + case ProfileStrategy::kOptimal: + return "Optimal"; + case ProfileStrategy::kRangeOptimal: + return "Range+Optimal"; + case ProfileStrategy::kImplicitBatchModeCompatible: + return "ImplicitBatchModeCompatible"; + } + return "Unknown"; +} + +// Parses `name` into `*strategy`. Matching is case-insensitive so that every +// string produced by ProfileStrategyToName() (e.g. "Range+Optimal") parses +// back to the same strategy. Returns InvalidArgument, leaving `*strategy` +// unchanged, when `name` is not a known strategy. +Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) { + const string name_lowercase = absl::AsciiStrToLower(name); + if (name_lowercase == "range") { + *strategy = ProfileStrategy::kRange; + } else if (name_lowercase == "optimal") { + *strategy = ProfileStrategy::kOptimal; + } else if (name_lowercase == "range+optimal") { + *strategy = ProfileStrategy::kRangeOptimal; + } else if (name_lowercase == "implicitbatchmodecompatible") { + *strategy = ProfileStrategy::kImplicitBatchModeCompatible; + } else { + return errors::InvalidArgument("Invalid profile strategy: ", name); + } + return Status::OK(); +} + #endif absl::string_view GetDeviceName(const Node* node) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 29b32b6514d..f9cb293a3db 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -200,6 +200,17 @@ absl::optional MergeIfCompatible( absl::optional MergeIfCompatible( const DeviceNameUtils::ParsedName& a, absl::string_view b); +// Optimization profile generation strategies. 
+enum class ProfileStrategy { + kRange, + kOptimal, + kRangeOptimal, + kImplicitBatchModeCompatible, +}; + +string ProfileStrategyToName(const ProfileStrategy strategy); +Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy); + #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index a2a41f5a03c..22eebdcf884 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -56,7 +56,8 @@ REGISTER_OP("TRTEngineOp") .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("fixed_input_size: bool = true") .Attr("output_shapes: list(shape) = []") - .Attr("static_engine: bool = true"); + .Attr("static_engine: bool = true") + .Attr("profile_strategy: string = ''"); } // namespace tensorflow #endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc index ca25a5840f5..9b3bf6b5acc 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -109,6 +109,10 @@ void TrtShapeOptimizationProfile::InitProfiles( VLOG(1) << "Creating profiles with ImplicitBatchModeCompatible strategy"; ImplicitBatchModeCompatibleStrategy(); break; + // Treat all other strategies the same as kOptimal for now. Implementing + // those is outlined in the dynamic shape support implementation plan. 
+ case ProfileStrategy::kRange: + case ProfileStrategy::kRangeOptimal: case ProfileStrategy::kOptimal: VLOG(1) << "Creating profiles with Optimal strategy"; OptimalStrategy(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h index 71d7d8b1667..854b94dbdd7 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h @@ -115,12 +115,6 @@ struct OptimizationProfileConfig { } }; -// Optimization profile generation strategies. -enum class ProfileStrategy { - kImplicitBatchModeCompatible, - kOptimal, -}; - // Manages Optimization profiles during TRT Engine construction. // // An optimization profile describes a range of dimensions for each TRT network diff --git a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py index ec9157d20f8..8af00e6860b 100644 --- a/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py @@ -163,6 +163,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): self._trt_test_params = None self._disable_non_trt_optimizers = False self._use_implicit_batch = True + self._profile_strategy = "Unknown" def setUp(self): """Setup method.""" @@ -264,8 +265,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def DisableNonTrtOptimizers(self): self._disable_non_trt_optimizers = True - def DisableImplicitBatchMode(self): + def SetDynamicShapeModeAndProfileStrategy(self, profile_strategy="Range"): self._use_implicit_batch = False + self._profile_strategy = profile_strategy def GetParams(self): """Returns a TfTrtIntegrationTestParams for the test.""" @@ -453,11 +455,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): if 
run_params.is_v2: converter_v2 = trt_convert.TrtGraphConverterV2( input_saved_model_dir=saved_model_dir, - conversion_params=conversion_params) + conversion_params=conversion_params, + use_dynamic_shape=not self._use_implicit_batch, + dynamic_shape_profile_strategy=self._profile_strategy) if self._disable_non_trt_optimizers: converter_v2._test_only_disable_non_trt_optimizers = True # pylint: disable=protected-access - if not self._use_implicit_batch: - converter_v2._test_only_use_implicit_batch = False # pylint: disable=protected-access return converter_v2 converter_v1 = trt_convert.TrtGraphConverter( @@ -873,6 +875,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): all_op_names.append(node.name) if node.op == "TRTEngineOp": trt_op_names.append(node.name) + if not self._use_implicit_batch: + self.assertEqual( + self._ToString(node.attr["profile_strategy"].s).lower(), + self._profile_strategy.lower()) all_op_names = self._Canonicalize(all_op_names) trt_op_names = self._RemoveGraphSequenceNumber( diff --git a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py index f5eb3c75653..b96d9b3b586 100644 --- a/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py +++ b/tensorflow/python/compiler/tensorrt/test/trt_mode_test.py @@ -125,8 +125,8 @@ class ExplicitBatchTest(TrtModeTestBase): def setUp(self): super().setUp() - # Diable implicit batch mode for testing explicit batch mode. 
- self.DisableImplicitBatchMode() + self.SetDynamicShapeModeAndProfileStrategy( + profile_strategy="ImplicitBatchModeCompatible") class DynamicShapesTest(TrtModeTestBase): @@ -162,7 +162,9 @@ class DynamicShapesTest(TrtModeTestBase): def setUp(self): super().setUp() - self.DisableImplicitBatchMode() + self.SetDynamicShapeModeAndProfileStrategy( + profile_strategy="ImplicitBatchModeCompatible") + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py index c930c49fa84..a16e9ad429c 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert.py @@ -316,8 +316,11 @@ def _get_tensorrt_rewriter_config(conversion_params, if max_batch_size is not None: optimizer.parameter_map["max_batch_size"].i = max_batch_size optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch + # While we accept case insensitive strings from the users, we only pass the + # strings in lower cases to TF-TRT converter. if not use_implicit_batch: - optimizer.parameter_map["profile_strategy"].s = _to_bytes(profile_strategy) + optimizer.parameter_map["profile_strategy"].s = _to_bytes( + profile_strategy.lower()) # Disabling optimizers should happen after defining the TF-TRT grappler pass # otherwise the template can overwrite the disablement. @@ -1013,6 +1016,7 @@ class TrtGraphConverterV2(object): else: self._use_dynamic_shape = use_dynamic_shape + self._profile_strategy = "Unknown" if self._use_dynamic_shape: if dynamic_shape_profile_strategy is None: self._profile_strategy = PROFILE_STRATEGY_RANGE @@ -1022,10 +1026,9 @@ class TrtGraphConverterV2(object): # Fields to support TF-TRT testing and shouldn't be used for other purpose. 
self._test_only_disable_non_trt_optimizers = False - self._test_only_use_implicit_batch = True def _need_trt_profiles(self): - return not self._test_only_use_implicit_batch + return self._use_dynamic_shape def _run_conversion(self, meta_graph_def): """Run Grappler's OptimizeGraph() tool to convert the graph. @@ -1042,7 +1045,8 @@ class TrtGraphConverterV2(object): is_dynamic_op=True, max_batch_size=None, disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers, - use_implicit_batch=self._test_only_use_implicit_batch) + use_implicit_batch=not self._use_dynamic_shape, + profile_strategy=self._profile_strategy) grappler_session_config.graph_options.rewrite_options.CopyFrom( custom_rewriter_config) return tf_optimizer.OptimizeGraph(