Merge branch 'pr-2361' (Fixes #2361)

This commit is contained in:
Reuben Morais 2019-09-13 12:16:49 +02:00
commit 2bf8161ca4
3 changed files with 16 additions and 13 deletions

23
util/config.py Normal file → Executable file
View File

@ -28,16 +28,6 @@ Config = ConfigSingleton() # pylint: disable=invalid-name
def initialize_globals(): def initialize_globals():
c = AttrDict() c = AttrDict()
# CPU device
c.cpu_device = '/cpu:0'
# Available GPU devices
c.available_devices = get_available_gpus()
# If there is no GPU available, we fall back to CPU based operation
if not c.available_devices:
c.available_devices = [c.cpu_device]
# Set default dropout rates # Set default dropout rates
if FLAGS.dropout_rate2 < 0: if FLAGS.dropout_rate2 < 0:
FLAGS.dropout_rate2 = FLAGS.dropout_rate FLAGS.dropout_rate2 = FLAGS.dropout_rate
@ -60,7 +50,18 @@ def initialize_globals():
# Standard session configuration that'll be used for all new sessions. # Standard session configuration that'll be used for all new sessions.
c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement, c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads, inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads) intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth))
# CPU device
c.cpu_device = '/cpu:0'
# Available GPU devices
c.available_devices = get_available_gpus(c.session_config)
# If there is no GPU available, we fall back to CPU based operation
if not c.available_devices:
c.available_devices = [c.cpu_device]
c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path)) c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))

View File

@ -76,6 +76,7 @@ def create_flags():
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work') f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.') f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')

5
util/gpu.py Normal file → Executable file
View File

@ -1,8 +1,9 @@
from tensorflow.python.client import device_lib from tensorflow.python.client import device_lib
def get_available_gpus():
def get_available_gpus(config):
r""" r"""
Returns the number of GPUs available on this system. Returns the number of GPUs available on this system.
""" """
local_device_protos = device_lib.list_local_devices() local_device_protos = device_lib.list_local_devices(session_config=config)
return [x.name for x in local_device_protos if x.device_type == 'GPU'] return [x.name for x in local_device_protos if x.device_type == 'GPU']