From f748283ee01059be52da5dada6e2157d9f6732ba Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Mon, 23 Mar 2020 18:08:14 -0700 Subject: [PATCH] Fix crash if set_visible_devices() is used with tf.keras.mixed_precision. Unfortunately, this required disabling the warning that would appear if mixed precision was used on a GPU that didn't fully support it. A warning will still appear if there is no GPU, but no log will appear if the user does have a GPU, because in that case we cannot tell if the GPU is supported or not. I will try to get the warning back by 2.3. PiperOrigin-RevId: 302561652 Change-Id: Ic73d06a4531a052009e83080de7af257042f33e1 --- .../device_compatibility_check.py | 21 ++++++++++++++-- .../mixed_precision/experimental/policy.py | 3 ++- .../experimental/policy_test.py | 24 ++++++++----------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py index d92c16d632f..9279c37bb52 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py +++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py @@ -22,6 +22,7 @@ import itertools from tensorflow.python.client import device_lib from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import gpu_util from tensorflow.python.platform import tf_logging @@ -133,7 +134,7 @@ def _log_device_compatibility_check(policy_name, device_attr_list): _logged_compatibility_check = False -def log_device_compatibility_check(policy_name): +def log_device_compatibility_check(policy_name, skip_local): """Logs a compatibility check if the devices support the policy. Currently only logs for the policy mixed_float16. 
 A log is shown only the @@ -141,6 +142,11 @@ def log_device_compatibility_check(policy_name): Args: policy_name: The name of the dtype policy. + skip_local: If True, do not call list_local_devices(). This is useful since + if list_local_devices() and tf.config.set_visible_devices() are both + called, TensorFlow will crash. However, since GPU names and compute + capabilities cannot be checked without list_local_devices(), setting this + to True means the function will only warn if there are no GPUs. """ global _logged_compatibility_check # In graph mode, calling list_local_devices may initialize some session state, @@ -149,5 +155,16 @@ def log_device_compatibility_check(policy_name): return _logged_compatibility_check = True device_attr_list = device_lib.list_local_devices() - _log_device_compatibility_check(policy_name, device_attr_list) + if not skip_local: + _log_device_compatibility_check(policy_name, device_attr_list) + return + # TODO(b/146009447): Create an API to replace list_local_devices(), then + # remove the skip_local parameter. 
+ gpus = config.list_physical_devices('GPU') + if not gpus and policy_name == 'mixed_float16': + tf_logging.warn( + '%s\n' + 'The dtype policy mixed_float16 may run slowly because ' + 'this machine does not have a GPU.\n%s' % + (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX)) diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index 9afc3ce9251..f9899679a86 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -333,7 +333,8 @@ class Policy(object): self._loss_scale = keras_loss_scale_module.get(loss_scale) if name in ('mixed_float16', 'mixed_bloat16'): - device_compatibility_check.log_device_compatibility_check(name) + device_compatibility_check.log_device_compatibility_check(name, + skip_local=True) def _parse_name(self, name): """Parses a Policy name into a compute and variable dtype. diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index b345039b406..ff809d061cb 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.python.eager import context +from tensorflow.python.framework import config as config_module from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras import combinations @@ -173,25 +174,20 @@ class PolicyTest(test.TestCase, parameterized.TestCase): def test_device_compatibility_warning(self): with context.eager_mode(): device_compatibility_check._logged_compatibility_check = False - with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \ - test.mock.patch.object(tf_logging, 'info') 
as mock_info: + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: mp_policy.Policy('mixed_float16') - if mock_warn.called: + if config_module.list_physical_devices('GPU'): + mock_warn.assert_not_called() + else: self.assertRegexpMatches( mock_warn.call_args[0][0], r'Mixed precision compatibility check \(mixed_float16\): WARNING.*') - mock_info.assert_not_called() - else: - self.assertRegexpMatches( - mock_info.call_args[0][0], - r'Mixed precision compatibility check \(mixed_float16\): OK.*') - # Assert message is only logged once - with test.mock.patch.object(tf_logging, 'warn') as mock_warn, \ - test.mock.patch.object(tf_logging, 'info') as mock_info: - mp_policy.Policy('mixed_float16') - mock_warn.assert_not_called() - mock_info.assert_not_called() + if config_module.list_physical_devices('GPU'): + # Assert message is only logged once + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('mixed_float16') + mock_warn.assert_not_called() @testing_utils.enable_v2_dtype_behavior def test_policy_scope(self):