Set disallow_output_partition_graphs
when creating a SavedModelBundleLite.
This removes the need to store a copy of the rewritten subgraph(s) used in the session, which can save a substantial amount of RAM. PiperOrigin-RevId: 273309216
This commit is contained in:
parent
ebb8264f85
commit
4d282aceb0
@ -440,8 +440,15 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
||||
SavedModelBundleLite* const bundle) {
|
||||
SavedModelBundle legacy_bundle;
|
||||
SessionOptions rewritten_options(session_options);
|
||||
// We disallow calls to Session::Extend() on the returned session, so we can
|
||||
// reduce memory consumption by not storing the original GraphDef.
|
||||
rewritten_options.config.mutable_experimental()
|
||||
->set_optimize_for_static_graph(true);
|
||||
// Disallowing the `RunOptions.output_partition_graphs` option (typically used
|
||||
// in debugging and tests) allows us to reduce memory consumption further by
|
||||
// not storing the rewritten subgraph for each signature.
|
||||
rewritten_options.config.mutable_experimental()
|
||||
->set_disable_output_partition_graphs(true);
|
||||
// TODO(mrry): Consider specializing the session creation to reduce peak
|
||||
// RAM consumption by using `Session::Create(GraphDef&&)`.
|
||||
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
|
||||
|
@ -13,19 +13,20 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
|
||||
#include "tensorflow/cc/saved_model/constants.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/cc/saved_model/signature_constants.h"
|
||||
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||
#include "tensorflow/core/example/example.pb.h"
|
||||
#include "tensorflow/core/example/feature.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -92,6 +93,15 @@ class LoaderTest : public ::testing::Test {
|
||||
test::ExpectTensorEqual<float>(
|
||||
outputs[0],
|
||||
test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
|
||||
|
||||
// Validate the `output_partition_graphs` is not supported.
|
||||
RunOptions run_options;
|
||||
run_options.set_output_partition_graphs(true);
|
||||
RunMetadata run_metadata;
|
||||
Status s =
|
||||
bundle.GetSession()->Run(run_options, {{input_name, input}},
|
||||
{output_name}, {}, &outputs, &run_metadata);
|
||||
ASSERT_TRUE(errors::IsInvalidArgument(s));
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user