Make default Keras ConfigProto use tf.config
PiperOrigin-RevId: 251659257
This commit is contained in:
parent
d15c612f77
commit
26d3fe8711
@ -69,6 +69,7 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
@ -521,14 +522,14 @@ def set_session(session):
|
||||
|
||||
|
||||
def get_default_session_config():
|
||||
if not os.environ.get('OMP_NUM_THREADS'):
|
||||
config = config_pb2.ConfigProto(allow_soft_placement=True)
|
||||
else:
|
||||
num_thread = int(os.environ.get('OMP_NUM_THREADS'))
|
||||
config = config_pb2.ConfigProto(
|
||||
intra_op_parallelism_threads=num_thread,
|
||||
inter_op_parallelism_threads=num_thread,
|
||||
allow_soft_placement=True)
|
||||
if os.environ.get('OMP_NUM_THREADS'):
|
||||
logging.warning(
|
||||
'OMP_NUM_THREADS is no longer used by the default Keras config. '
|
||||
'To configure the number of threads, use tf.config.threading APIs.')
|
||||
|
||||
config = context.context().config
|
||||
config.allow_soft_placement = True
|
||||
|
||||
return config
|
||||
|
||||
|
||||
|
@ -24,6 +24,7 @@ import scipy.sparse
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op,
|
||||
|
||||
class BackendResetTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
def test_new_config(self):
|
||||
# User defined jit setting
|
||||
config.set_optimizer_jit(False)
|
||||
sess = keras.backend.get_session()
|
||||
default_config = context.context().config
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
keras.backend.clear_session()
|
||||
|
||||
# New session has the same jit setting
|
||||
sess = keras.backend.get_session()
|
||||
default_config = context.context().config
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
keras.backend.clear_session()
|
||||
|
||||
# Change respected
|
||||
config.set_optimizer_jit(True)
|
||||
sess = keras.backend.get_session()
|
||||
default_config = context.context().config
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
keras.backend.clear_session()
|
||||
|
||||
# We can't use the normal parameterized decorator because the test session
|
||||
# will block graph clearing.
|
||||
@parameterized.named_parameters(('_v1', context.graph_mode),
|
||||
|
Loading…
Reference in New Issue
Block a user