Fix crash if set_visible_devices() is used with tf.keras.mixed_precision.
Unfortunately, this required disabling the warning that would appear if mixed precision was used on a GPU that didn't fully support it. A warning will still appear if there is no GPU, but no log will appear if the user does have a GPU, because in that case we cannot tell whether the GPU is supported or not. I will try to get the warning back by 2.3. PiperOrigin-RevId: 302561652 Change-Id: Ic73d06a4531a052009e83080de7af257042f33e1
This commit is contained in:
parent
8581cdd0d0
commit
f748283ee0
@ -22,6 +22,7 @@ import itertools
|
|||||||
|
|
||||||
from tensorflow.python.client import device_lib
|
from tensorflow.python.client import device_lib
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import gpu_util
|
from tensorflow.python.framework import gpu_util
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
@ -133,7 +134,7 @@ def _log_device_compatibility_check(policy_name, device_attr_list):
|
|||||||
_logged_compatibility_check = False
|
_logged_compatibility_check = False
|
||||||
|
|
||||||
|
|
||||||
def log_device_compatibility_check(policy_name):
|
def log_device_compatibility_check(policy_name, skip_local):
|
||||||
"""Logs a compatibility check if the devices support the policy.
|
"""Logs a compatibility check if the devices support the policy.
|
||||||
|
|
||||||
Currently only logs for the policy mixed_float16. A log is shown only the
|
Currently only logs for the policy mixed_float16. A log is shown only the
|
||||||
@ -141,6 +142,11 @@ def log_device_compatibility_check(policy_name):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_name: The name of the dtype policy.
|
policy_name: The name of the dtype policy.
|
||||||
|
skip_local: If True, do not call list_local_devices(). This is useful since
|
||||||
|
if list_local_devices() and tf.config.set_visible_devices() are both
|
||||||
|
called, TensorFlow will crash. However, since GPU names and compute
|
||||||
|
capabilities cannot be checked without list_local_devices(), setting this
|
||||||
|
to True means the function will only warn if there are no GPUs.
|
||||||
"""
|
"""
|
||||||
global _logged_compatibility_check
|
global _logged_compatibility_check
|
||||||
# In graph mode, calling list_local_devices may initialize some session state,
|
# In graph mode, calling list_local_devices may initialize some session state,
|
||||||
@ -149,5 +155,16 @@ def log_device_compatibility_check(policy_name):
|
|||||||
return
|
return
|
||||||
_logged_compatibility_check = True
|
_logged_compatibility_check = True
|
||||||
device_attr_list = device_lib.list_local_devices()
|
device_attr_list = device_lib.list_local_devices()
|
||||||
_log_device_compatibility_check(policy_name, device_attr_list)
|
if not skip_local:
|
||||||
|
_log_device_compatibility_check(policy_name, device_attr_list)
|
||||||
|
return
|
||||||
|
|
||||||
|
# TODO(b/146009447): Create an API to replace list_local_devices(), then
|
||||||
|
# remove the skip_local parameter.
|
||||||
|
gpus = config.list_physical_devices('GPU')
|
||||||
|
if not gpus and policy_name == 'mixed_float16':
|
||||||
|
tf_logging.warn(
|
||||||
|
'%s\n'
|
||||||
|
'The dtype policy mixed_float16 may run slowly because '
|
||||||
|
'this machine does not have a GPU.\n%s' %
|
||||||
|
(_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX))
|
||||||
|
@ -333,7 +333,8 @@ class Policy(object):
|
|||||||
self._loss_scale = keras_loss_scale_module.get(loss_scale)
|
self._loss_scale = keras_loss_scale_module.get(loss_scale)
|
||||||
|
|
||||||
if name in ('mixed_float16', 'mixed_bloat16'):
|
if name in ('mixed_float16', 'mixed_bloat16'):
|
||||||
device_compatibility_check.log_device_compatibility_check(name)
|
device_compatibility_check.log_device_compatibility_check(name,
|
||||||
|
skip_local=True)
|
||||||
|
|
||||||
def _parse_name(self, name):
|
def _parse_name(self, name):
|
||||||
"""Parses a Policy name into a compute and variable dtype.
|
"""Parses a Policy name into a compute and variable dtype.
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import config as config_module
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras import combinations
|
from tensorflow.python.keras import combinations
|
||||||
@ -173,25 +174,20 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
|
|||||||
def test_device_compatibility_warning(self):
|
def test_device_compatibility_warning(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
device_compatibility_check._logged_compatibility_check = False
|
device_compatibility_check._logged_compatibility_check = False
|
||||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
test.mock.patch.object(tf_logging, 'info') as mock_info:
|
|
||||||
mp_policy.Policy('mixed_float16')
|
mp_policy.Policy('mixed_float16')
|
||||||
if mock_warn.called:
|
if config_module.list_physical_devices('GPU'):
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
else:
|
||||||
self.assertRegexpMatches(
|
self.assertRegexpMatches(
|
||||||
mock_warn.call_args[0][0],
|
mock_warn.call_args[0][0],
|
||||||
r'Mixed precision compatibility check \(mixed_float16\): WARNING.*')
|
r'Mixed precision compatibility check \(mixed_float16\): WARNING.*')
|
||||||
mock_info.assert_not_called()
|
|
||||||
else:
|
|
||||||
self.assertRegexpMatches(
|
|
||||||
mock_info.call_args[0][0],
|
|
||||||
r'Mixed precision compatibility check \(mixed_float16\): OK.*')
|
|
||||||
|
|
||||||
# Assert message is only logged once
|
if config_module.list_physical_devices('GPU'):
|
||||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \
|
# Assert message is only logged once
|
||||||
test.mock.patch.object(tf_logging, 'info') as mock_info:
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
mp_policy.Policy('mixed_float16')
|
mp_policy.Policy('mixed_float16')
|
||||||
mock_warn.assert_not_called()
|
mock_warn.assert_not_called()
|
||||||
mock_info.assert_not_called()
|
|
||||||
|
|
||||||
@testing_utils.enable_v2_dtype_behavior
|
@testing_utils.enable_v2_dtype_behavior
|
||||||
def test_policy_scope(self):
|
def test_policy_scope(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user