From 4d282aceb0f289fd882cae8e73f9dc114ecfefb2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 7 Oct 2019 09:31:03 -0700 Subject: [PATCH] Set `disallow_output_partition_graphs` when creating a SavedModelBundleLite. This removes the need to store a copy of the rewritten subgraph(s) used in the session, which can save a substantial amount of RAM. PiperOrigin-RevId: 273309216 --- tensorflow/cc/saved_model/loader.cc | 7 +++++++ .../cc/saved_model/saved_model_bundle_lite_test.cc | 14 ++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index df36245dc95..b8834b3532b 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -440,8 +440,15 @@ Status LoadSavedModel(const SessionOptions& session_options, SavedModelBundleLite* const bundle) { SavedModelBundle legacy_bundle; SessionOptions rewritten_options(session_options); + // We disallow calls to Session::Extend() on the returned session, so we can + // reduce memory consumption by not storing the original GraphDef. rewritten_options.config.mutable_experimental() ->set_optimize_for_static_graph(true); + // Disallowing the `RunOptions.output_partition_graphs` option (typically used + // in debugging and tests) allows us to reduce memory consumption further by + // not storing the rewritten subgraph for each signature. + rewritten_options.config.mutable_experimental() + ->set_disable_output_partition_graphs(true); // TODO(mrry): Consider specializing the session creation to reduce peak // RAM consumption by using `Session::Create(GraphDef&&)`. TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir, diff --git a/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc index 1ed8cab773a..604fc412800 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_lite_test.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/saved_model/loader.h" - #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { namespace { @@ -92,6 +93,15 @@ class LoaderTest : public ::testing::Test { test::ExpectTensorEqual( outputs[0], test::AsTensor({2, 2.5, 3, 3.5}, TensorShape({4, 1}))); + + // Validate the `output_partition_graphs` is not supported. + RunOptions run_options; + run_options.set_output_partition_graphs(true); + RunMetadata run_metadata; + Status s = + bundle.GetSession()->Run(run_options, {{input_name, input}}, + {output_name}, {}, &outputs, &run_metadata); + ASSERT_TRUE(errors::IsInvalidArgument(s)); } };