diff --git a/tensorflow/python/distribute/multi_process_lib.py b/tensorflow/python/distribute/multi_process_lib.py index 12b81db7189..9b7851439b7 100644 --- a/tensorflow/python/distribute/multi_process_lib.py +++ b/tensorflow/python/distribute/multi_process_lib.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import multiprocessing as _multiprocessing +import os import unittest from tensorflow.python.platform import test @@ -42,6 +43,7 @@ class Process(object): def test_main(): """Main function to be called within `__main__` of a test file.""" + os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' test.main()