Fix crash if set_visible_devices() is used with tf.keras.mixed_precision.
Unfortunately, this required disabling the warning that would appear if mixed precision was used on a GPU that didn't fully support it. A warning will still appear if there is no GPU, but no log will appear if the user does have a GPU, because in that case we cannot tell if the GPU is support or not. I will try to get the warning back by 2.3. PiperOrigin-RevId: 302561652 Change-Id: Ic73d06a4531a052009e83080de7af257042f33e1
This commit is contained in:
parent
8581cdd0d0
commit
f748283ee0
@ -22,6 +22,7 @@ import itertools
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import gpu_util
|
||||
from tensorflow.python.platform import tf_logging
|
||||
|
||||
@ -133,7 +134,7 @@ def _log_device_compatibility_check(policy_name, device_attr_list):
|
||||
_logged_compatibility_check = False
|
||||
|
||||
|
||||
def log_device_compatibility_check(policy_name):
|
||||
def log_device_compatibility_check(policy_name, skip_local):
|
||||
"""Logs a compatibility check if the devices support the policy.
|
||||
|
||||
Currently only logs for the policy mixed_float16. A log is shown only the
|
||||
@ -141,6 +142,11 @@ def log_device_compatibility_check(policy_name):
|
||||
|
||||
Args:
|
||||
policy_name: The name of the dtype policy.
|
||||
skip_local: If True, do not call list_local_devices(). This is useful since
|
||||
if list_local_devices() and tf.config.set_visible_devices() are both
|
||||
called, TensorFlow will crash. However, since GPU names and compute
|
||||
capabilities cannot be checked without list_local_devices(), setting this
|
||||
to True means the function will only warn if there are no GPUs.
|
||||
"""
|
||||
global _logged_compatibility_check
|
||||
# In graph mode, calling list_local_devices may initialize some session state,
|
||||
@ -149,5 +155,16 @@ def log_device_compatibility_check(policy_name):
|
||||
return
|
||||
_logged_compatibility_check = True
|
||||
device_attr_list = device_lib.list_local_devices()
|
||||
_log_device_compatibility_check(policy_name, device_attr_list)
|
||||
if not skip_local:
|
||||
_log_device_compatibility_check(policy_name, device_attr_list)
|
||||
return
|
||||
|
||||
# TODO(b/146009447): Create an API to replace list_local_devices(), then
|
||||
# remove the skip_local paramater.
|
||||
gpus = config.list_physical_devices('GPU')
|
||||
if not gpus and policy_name == 'mixed_float16':
|
||||
tf_logging.warn(
|
||||
'%s\n'
|
||||
'The dtype policy mixed_float16 may run slowly because '
|
||||
'this machine does not have a GPU.\n%s' %
|
||||
(_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX))
|
||||
|
@ -333,7 +333,8 @@ class Policy(object):
|
||||
self._loss_scale = keras_loss_scale_module.get(loss_scale)
|
||||
|
||||
if name in ('mixed_float16', 'mixed_bloat16'):
|
||||
device_compatibility_check.log_device_compatibility_check(name)
|
||||
device_compatibility_check.log_device_compatibility_check(name,
|
||||
skip_local=True)
|
||||
|
||||
def _parse_name(self, name):
|
||||
"""Parses a Policy name into a compute and variable dtype.
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config as config_module
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import combinations
|
||||
@ -173,25 +174,20 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
|
||||
def test_device_compatibility_warning(self):
|
||||
with context.eager_mode():
|
||||
device_compatibility_check._logged_compatibility_check = False
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
|
||||
test.mock.patch.object(tf_logging, 'info') as mock_info:
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('mixed_float16')
|
||||
if mock_warn.called:
|
||||
if config_module.list_physical_devices('GPU'):
|
||||
mock_warn.assert_not_called()
|
||||
else:
|
||||
self.assertRegexpMatches(
|
||||
mock_warn.call_args[0][0],
|
||||
r'Mixed precision compatibility check \(mixed_float16\): WARNING.*')
|
||||
mock_info.assert_not_called()
|
||||
else:
|
||||
self.assertRegexpMatches(
|
||||
mock_info.call_args[0][0],
|
||||
r'Mixed precision compatibility check \(mixed_float16\): OK.*')
|
||||
|
||||
# Assert message is only logged once
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
|
||||
test.mock.patch.object(tf_logging, 'info') as mock_info:
|
||||
mp_policy.Policy('mixed_float16')
|
||||
mock_warn.assert_not_called()
|
||||
mock_info.assert_not_called()
|
||||
if config_module.list_physical_devices('GPU'):
|
||||
# Assert message is only logged once
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('mixed_float16')
|
||||
mock_warn.assert_not_called()
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_policy_scope(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user