Set disallow_output_partition_graphs
when creating a SavedModelBundleLite.
This removes the need to store a copy of the rewritten subgraph(s) used in the session, which can save a substantial amount of RAM. PiperOrigin-RevId: 273309216
This commit is contained in:
parent
ebb8264f85
commit
4d282aceb0
@ -440,8 +440,15 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
|||||||
SavedModelBundleLite* const bundle) {
|
SavedModelBundleLite* const bundle) {
|
||||||
SavedModelBundle legacy_bundle;
|
SavedModelBundle legacy_bundle;
|
||||||
SessionOptions rewritten_options(session_options);
|
SessionOptions rewritten_options(session_options);
|
||||||
|
// We disallow calls to Session::Extend() on the returned session, so we can
|
||||||
|
// reduce memory consumption by not storing the original GraphDef.
|
||||||
rewritten_options.config.mutable_experimental()
|
rewritten_options.config.mutable_experimental()
|
||||||
->set_optimize_for_static_graph(true);
|
->set_optimize_for_static_graph(true);
|
||||||
|
// Disallowing the `RunOptions.output_partition_graphs` option (typically used
|
||||||
|
// in debugging and tests) allows us to reduce memory consumption further by
|
||||||
|
// not storing the rewritten subgraph for each signature.
|
||||||
|
rewritten_options.config.mutable_experimental()
|
||||||
|
->set_disable_output_partition_graphs(true);
|
||||||
// TODO(mrry): Consider specializing the session creation to reduce peak
|
// TODO(mrry): Consider specializing the session creation to reduce peak
|
||||||
// RAM consumption by using `Session::Create(GraphDef&&)`.
|
// RAM consumption by using `Session::Create(GraphDef&&)`.
|
||||||
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
|
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
|
||||||
|
@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/cc/saved_model/loader.h"
|
|
||||||
|
|
||||||
#include "tensorflow/cc/saved_model/constants.h"
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/cc/saved_model/loader.h"
|
||||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||||
#include "tensorflow/core/example/example.pb.h"
|
#include "tensorflow/core/example/example.pb.h"
|
||||||
#include "tensorflow/core/example/feature.pb.h"
|
#include "tensorflow/core/example/feature.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
@ -92,6 +93,15 @@ class LoaderTest : public ::testing::Test {
|
|||||||
test::ExpectTensorEqual<float>(
|
test::ExpectTensorEqual<float>(
|
||||||
outputs[0],
|
outputs[0],
|
||||||
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
|
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
|
||||||
|
|
||||||
|
// Validate the `output_partition_graphs` is not supported.
|
||||||
|
RunOptions run_options;
|
||||||
|
run_options.set_output_partition_graphs(true);
|
||||||
|
RunMetadata run_metadata;
|
||||||
|
Status s =
|
||||||
|
bundle.GetSession()->Run(run_options, {{input_name, input}},
|
||||||
|
{output_name}, {}, &outputs, &run_metadata);
|
||||||
|
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user