diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 73dd97fdb92..f9205ed7b45 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -69,6 +69,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables as variables_module +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib from tensorflow.python.util import nest from tensorflow.python.util import tf_contextlib @@ -521,14 +522,14 @@ def set_session(session): def get_default_session_config(): - if not os.environ.get('OMP_NUM_THREADS'): - config = config_pb2.ConfigProto(allow_soft_placement=True) - else: - num_thread = int(os.environ.get('OMP_NUM_THREADS')) - config = config_pb2.ConfigProto( - intra_op_parallelism_threads=num_thread, - inter_op_parallelism_threads=num_thread, - allow_soft_placement=True) + if os.environ.get('OMP_NUM_THREADS'): + logging.warning( + 'OMP_NUM_THREADS is no longer used by the default Keras config. ' + 'To configure the number of threads, use tf.config.threading APIs.') + + config = context.context().config + config.allow_soft_placement = True + return config diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 3507f2ee4a9..e3bc5467261 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -24,6 +24,7 @@ import scipy.sparse from tensorflow.core.protobuf import config_pb2 from tensorflow.python import keras from tensorflow.python.eager import context +from tensorflow.python.framework import config from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op, class BackendResetTest(test.TestCase, parameterized.TestCase): + @test_util.run_all_in_graph_and_eager_modes + def test_new_config(self): + # User defined jit setting + config.set_optimizer_jit(False) + sess = keras.backend.get_session() + default_config = context.context().config + self.assertEqual( + sess._config.graph_options.optimizer_options.global_jit_level, + default_config.graph_options.optimizer_options.global_jit_level) + keras.backend.clear_session() + + # New session has the same jit setting + sess = keras.backend.get_session() + default_config = context.context().config + self.assertEqual( + sess._config.graph_options.optimizer_options.global_jit_level, + default_config.graph_options.optimizer_options.global_jit_level) + keras.backend.clear_session() + + # Change respected + config.set_optimizer_jit(True) + sess = keras.backend.get_session() + default_config = context.context().config + self.assertEqual( + sess._config.graph_options.optimizer_options.global_jit_level, + default_config.graph_options.optimizer_options.global_jit_level) + keras.backend.clear_session() + # We can't use the normal parameterized decorator because the test session # will block graph clearing. @parameterized.named_parameters(('_v1', context.graph_mode),