Make default Keras ConfigProto use tf.config

PiperOrigin-RevId: 251659257
This commit is contained in:
Gaurav Jain 2019-06-05 09:38:01 -07:00 committed by TensorFlower Gardener
parent d15c612f77
commit 26d3fe8711
2 changed files with 38 additions and 8 deletions

View File

@ -69,6 +69,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@ -521,14 +522,14 @@ def set_session(session):
def get_default_session_config():
if not os.environ.get('OMP_NUM_THREADS'):
config = config_pb2.ConfigProto(allow_soft_placement=True)
else:
num_thread = int(os.environ.get('OMP_NUM_THREADS'))
config = config_pb2.ConfigProto(
intra_op_parallelism_threads=num_thread,
inter_op_parallelism_threads=num_thread,
allow_soft_placement=True)
if os.environ.get('OMP_NUM_THREADS'):
logging.warning(
'OMP_NUM_THREADS is no longer used by the default Keras config. '
'To configure the number of threads, use tf.config.threading APIs.')
config = context.context().config
config.allow_soft_placement = True
return config

View File

@ -24,6 +24,7 @@ import scipy.sparse
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op,
class BackendResetTest(test.TestCase, parameterized.TestCase):
@test_util.run_all_in_graph_and_eager_modes
def test_new_config(self):
# User defined jit setting
config.set_optimizer_jit(False)
sess = keras.backend.get_session()
default_config = context.context().config
self.assertEqual(
sess._config.graph_options.optimizer_options.global_jit_level,
default_config.graph_options.optimizer_options.global_jit_level)
keras.backend.clear_session()
# New session has the same jit setting
sess = keras.backend.get_session()
default_config = context.context().config
self.assertEqual(
sess._config.graph_options.optimizer_options.global_jit_level,
default_config.graph_options.optimizer_options.global_jit_level)
keras.backend.clear_session()
# Change respected
config.set_optimizer_jit(True)
sess = keras.backend.get_session()
default_config = context.context().config
self.assertEqual(
sess._config.graph_options.optimizer_options.global_jit_level,
default_config.graph_options.optimizer_options.global_jit_level)
keras.backend.clear_session()
# We can't use the normal parameterized decorator because the test session
# will block graph clearing.
@parameterized.named_parameters(('_v1', context.graph_mode),