Make default Keras ConfigProto use tf.config
PiperOrigin-RevId: 251659257
This commit is contained in:
parent
d15c612f77
commit
26d3fe8711
@ -69,6 +69,7 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
from tensorflow.python.ops import variables as variables_module
|
from tensorflow.python.ops import variables as variables_module
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
@ -521,14 +522,14 @@ def set_session(session):
|
|||||||
|
|
||||||
|
|
||||||
def get_default_session_config():
|
def get_default_session_config():
|
||||||
if not os.environ.get('OMP_NUM_THREADS'):
|
if os.environ.get('OMP_NUM_THREADS'):
|
||||||
config = config_pb2.ConfigProto(allow_soft_placement=True)
|
logging.warning(
|
||||||
else:
|
'OMP_NUM_THREADS is no longer used by the default Keras config. '
|
||||||
num_thread = int(os.environ.get('OMP_NUM_THREADS'))
|
'To configure the number of threads, use tf.config.threading APIs.')
|
||||||
config = config_pb2.ConfigProto(
|
|
||||||
intra_op_parallelism_threads=num_thread,
|
config = context.context().config
|
||||||
inter_op_parallelism_threads=num_thread,
|
config.allow_soft_placement = True
|
||||||
allow_soft_placement=True)
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ import scipy.sparse
|
|||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -94,6 +95,34 @@ def compare_two_inputs_op_to_numpy(keras_op,
|
|||||||
|
|
||||||
class BackendResetTest(test.TestCase, parameterized.TestCase):
|
class BackendResetTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
def test_new_config(self):
|
||||||
|
# User defined jit setting
|
||||||
|
config.set_optimizer_jit(False)
|
||||||
|
sess = keras.backend.get_session()
|
||||||
|
default_config = context.context().config
|
||||||
|
self.assertEqual(
|
||||||
|
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||||
|
default_config.graph_options.optimizer_options.global_jit_level)
|
||||||
|
keras.backend.clear_session()
|
||||||
|
|
||||||
|
# New session has the same jit setting
|
||||||
|
sess = keras.backend.get_session()
|
||||||
|
default_config = context.context().config
|
||||||
|
self.assertEqual(
|
||||||
|
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||||
|
default_config.graph_options.optimizer_options.global_jit_level)
|
||||||
|
keras.backend.clear_session()
|
||||||
|
|
||||||
|
# Change respected
|
||||||
|
config.set_optimizer_jit(True)
|
||||||
|
sess = keras.backend.get_session()
|
||||||
|
default_config = context.context().config
|
||||||
|
self.assertEqual(
|
||||||
|
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||||
|
default_config.graph_options.optimizer_options.global_jit_level)
|
||||||
|
keras.backend.clear_session()
|
||||||
|
|
||||||
# We can't use the normal parameterized decorator because the test session
|
# We can't use the normal parameterized decorator because the test session
|
||||||
# will block graph clearing.
|
# will block graph clearing.
|
||||||
@parameterized.named_parameters(('_v1', context.graph_mode),
|
@parameterized.named_parameters(('_v1', context.graph_mode),
|
||||||
|
Loading…
Reference in New Issue
Block a user