Fix crash if set_visible_devices() is used with tf.keras.mixed_precision.

Unfortunately, this required disabling the warning that would appear if mixed precision was used on a GPU that didn't fully support it. A warning will still appear if there is no GPU, but no log will appear if the user does have a GPU, because in that case we cannot tell whether the GPU is supported or not. I will try to get the warning back by 2.3.

PiperOrigin-RevId: 302561652
Change-Id: Ic73d06a4531a052009e83080de7af257042f33e1
This commit is contained in:
Reed Wanderman-Milne 2020-03-23 18:08:14 -07:00 committed by TensorFlower Gardener
parent 8581cdd0d0
commit f748283ee0
3 changed files with 31 additions and 17 deletions

View File

@ -22,6 +22,7 @@ import itertools
from tensorflow.python.client import device_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import gpu_util
from tensorflow.python.platform import tf_logging
@ -133,7 +134,7 @@ def _log_device_compatibility_check(policy_name, device_attr_list):
_logged_compatibility_check = False
def log_device_compatibility_check(policy_name):
def log_device_compatibility_check(policy_name, skip_local):
"""Logs a compatibility check if the devices support the policy.
Currently only logs for the policy mixed_float16. A log is shown only the
@ -141,6 +142,11 @@ def log_device_compatibility_check(policy_name):
Args:
policy_name: The name of the dtype policy.
skip_local: If True, do not call list_local_devices(). This is useful since
if list_local_devices() and tf.config.set_visible_devices() are both
called, TensorFlow will crash. However, since GPU names and compute
capabilities cannot be checked without list_local_devices(), setting this
to True means the function will only warn if there are no GPUs.
"""
global _logged_compatibility_check
# In graph mode, calling list_local_devices may initialize some session state,
@ -149,5 +155,16 @@ def log_device_compatibility_check(policy_name):
return
_logged_compatibility_check = True
device_attr_list = device_lib.list_local_devices()
_log_device_compatibility_check(policy_name, device_attr_list)
if not skip_local:
_log_device_compatibility_check(policy_name, device_attr_list)
return
# TODO(b/146009447): Create an API to replace list_local_devices(), then
# remove the skip_local parameter.
gpus = config.list_physical_devices('GPU')
if not gpus and policy_name == 'mixed_float16':
tf_logging.warn(
'%s\n'
'The dtype policy mixed_float16 may run slowly because '
'this machine does not have a GPU.\n%s' %
(_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX))

View File

@ -333,7 +333,8 @@ class Policy(object):
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if name in ('mixed_float16', 'mixed_bloat16'):
device_compatibility_check.log_device_compatibility_check(name)
device_compatibility_check.log_device_compatibility_check(name,
skip_local=True)
def _parse_name(self, name):
"""Parses a Policy name into a compute and variable dtype.

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import config as config_module
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import combinations
@ -173,25 +174,20 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
def test_device_compatibility_warning(self):
with context.eager_mode():
device_compatibility_check._logged_compatibility_check = False
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
test.mock.patch.object(tf_logging, 'info') as mock_info:
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('mixed_float16')
if mock_warn.called:
if config_module.list_physical_devices('GPU'):
mock_warn.assert_not_called()
else:
self.assertRegexpMatches(
mock_warn.call_args[0][0],
r'Mixed precision compatibility check \(mixed_float16\): WARNING.*')
mock_info.assert_not_called()
else:
self.assertRegexpMatches(
mock_info.call_args[0][0],
r'Mixed precision compatibility check \(mixed_float16\): OK.*')
# Assert message is only logged once
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
test.mock.patch.object(tf_logging, 'info') as mock_info:
mp_policy.Policy('mixed_float16')
mock_warn.assert_not_called()
mock_info.assert_not_called()
if config_module.list_physical_devices('GPU'):
# Assert message is only logged once
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('mixed_float16')
mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior
def test_policy_scope(self):