Merge branch 'pr-2361' (Fixes #2361)
This commit is contained in:
commit
2bf8161ca4
|
@ -28,16 +28,6 @@ Config = ConfigSingleton() # pylint: disable=invalid-name
|
||||||
def initialize_globals():
|
def initialize_globals():
|
||||||
c = AttrDict()
|
c = AttrDict()
|
||||||
|
|
||||||
# CPU device
|
|
||||||
c.cpu_device = '/cpu:0'
|
|
||||||
|
|
||||||
# Available GPU devices
|
|
||||||
c.available_devices = get_available_gpus()
|
|
||||||
|
|
||||||
# If there is no GPU available, we fall back to CPU based operation
|
|
||||||
if not c.available_devices:
|
|
||||||
c.available_devices = [c.cpu_device]
|
|
||||||
|
|
||||||
# Set default dropout rates
|
# Set default dropout rates
|
||||||
if FLAGS.dropout_rate2 < 0:
|
if FLAGS.dropout_rate2 < 0:
|
||||||
FLAGS.dropout_rate2 = FLAGS.dropout_rate
|
FLAGS.dropout_rate2 = FLAGS.dropout_rate
|
||||||
|
@ -60,7 +50,18 @@ def initialize_globals():
|
||||||
# Standard session configuration that'll be used for all new sessions.
|
# Standard session configuration that'll be used for all new sessions.
|
||||||
c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
|
c.session_config = tfv1.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement,
|
||||||
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
|
inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
|
||||||
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads)
|
intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
|
||||||
|
gpu_options=tfv1.GPUOptions(allow_growth=FLAGS.use_allow_growth))
|
||||||
|
|
||||||
|
# CPU device
|
||||||
|
c.cpu_device = '/cpu:0'
|
||||||
|
|
||||||
|
# Available GPU devices
|
||||||
|
c.available_devices = get_available_gpus(c.session_config)
|
||||||
|
|
||||||
|
# If there is no GPU available, we fall back to CPU based operation
|
||||||
|
if not c.available_devices:
|
||||||
|
c.available_devices = [c.cpu_device]
|
||||||
|
|
||||||
c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
|
c.alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
|
||||||
|
|
||||||
|
|
|
@ -76,6 +76,7 @@ def create_flags():
|
||||||
|
|
||||||
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
||||||
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
|
||||||
|
f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
|
||||||
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
|
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
|
||||||
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
|
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
from tensorflow.python.client import device_lib
|
from tensorflow.python.client import device_lib
|
||||||
|
|
||||||
def get_available_gpus():
|
|
||||||
|
def get_available_gpus(config):
|
||||||
r"""
|
r"""
|
||||||
Returns the number of GPUs available on this system.
|
Returns the number of GPUs available on this system.
|
||||||
"""
|
"""
|
||||||
local_device_protos = device_lib.list_local_devices()
|
local_device_protos = device_lib.list_local_devices(session_config=config)
|
||||||
return [x.name for x in local_device_protos if x.device_type == 'GPU']
|
return [x.name for x in local_device_protos if x.device_type == 'GPU']
|
||||||
|
|
Loading…
Reference in New Issue