Set disallow_output_partition_graphs when creating a SavedModelBundleLite.

This removes the need to store a copy of the rewritten subgraph(s) used in the session, which can save a substantial amount of RAM.

PiperOrigin-RevId: 273309216
This commit is contained in:
Derek Murray 2019-10-07 09:31:03 -07:00 committed by TensorFlower Gardener
parent ebb8264f85
commit 4d282aceb0
2 changed files with 19 additions and 2 deletions

View File

@ -440,8 +440,15 @@ Status LoadSavedModel(const SessionOptions& session_options,
SavedModelBundleLite* const bundle) {
SavedModelBundle legacy_bundle;
SessionOptions rewritten_options(session_options);
// We disallow calls to Session::Extend() on the returned session, so we can
// reduce memory consumption by not storing the original GraphDef.
rewritten_options.config.mutable_experimental()
->set_optimize_for_static_graph(true);
// Disallowing the `RunOptions.output_partition_graphs` option (typically used
// in debugging and tests) allows us to reduce memory consumption further by
// not storing the rewritten subgraph for each signature.
rewritten_options.config.mutable_experimental()
->set_disable_output_partition_graphs(true);
// TODO(mrry): Consider specializing the session creation to reduce peak
// RAM consumption by using `Session::Create(GraphDef&&)`.
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,

View File

@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
namespace {
@ -92,6 +93,15 @@ class LoaderTest : public ::testing::Test {
test::ExpectTensorEqual<float>(
outputs[0],
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
// Validate the `output_partition_graphs` is not supported.
RunOptions run_options;
run_options.set_output_partition_graphs(true);
RunMetadata run_metadata;
Status s =
bundle.GetSession()->Run(run_options, {{input_name, input}},
{output_name}, {}, &outputs, &run_metadata);
ASSERT_TRUE(errors::IsInvalidArgument(s));
}
};