diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 6c1047bcd52..fc6a744a673 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -2060,6 +2060,47 @@ def device(name): return context().device(name) +# Expose some properties of Context as internally public APIs (b/160348781). +@tf_export("__internal__.eager_context.get_config", v1=[]) +def get_config(): + """Get the ConfigProto of Context. + + Returns: + The ConfigProto of Context. + """ + return context().config + + +@tf_export("__internal__.eager_context.get_device_name", v1=[]) +def get_device_name(): + """Get the device name for the current thread. + + Returns: + The device name for the current thread. + """ + return context().device_name + + +@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[]) +def set_soft_device_placement(enabled): + """Set if soft device placements should be allowed. + + Args: + enabled: Whether to enable soft device placement. + """ + context().soft_device_placement = enabled + + +@tf_export("__internal__.eager_context.get_executor", v1=[]) +def get_executor(): + """Get the Executor of the current thread. + + Returns: + The Executor of the current thread. + """ + return context().executor + + @tf_export("debugging.get_log_device_placement") def get_log_device_placement(): """Get if device placements are logged. diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 6c5ce5755f6..c1da720c8ff 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -41,6 +41,7 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.eager import context from tensorflow.python.eager import function as eager_function from tensorflow.python.eager import lift_to_graph +from tensorflow.python.eager.context import get_config from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import config from tensorflow.python.framework import constant_op @@ -821,7 +822,7 @@ def get_default_session_config(): 'OMP_NUM_THREADS is no longer used by the default Keras config. ' 'To configure the number of threads, use tf.config.threading APIs.') - config = context.context().config + config = get_config() config.allow_soft_placement = True return config diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 4dd2a45eba6..8943afaf993 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -26,6 +26,7 @@ import scipy.sparse from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import def_function +from tensorflow.python.eager.context import get_config from tensorflow.python.framework import config from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -106,7 +107,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase): # User defined jit setting config.set_optimizer_jit(False) sess = backend.get_session() - default_config = context.context().config + default_config = get_config() self.assertEqual( sess._config.graph_options.optimizer_options.global_jit_level, default_config.graph_options.optimizer_options.global_jit_level) @@ -114,7 +115,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase): # New session has the same jit setting sess = backend.get_session() - default_config = context.context().config + default_config = get_config() self.assertEqual( sess._config.graph_options.optimizer_options.global_jit_level, default_config.graph_options.optimizer_options.global_jit_level) @@ -123,7 +124,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase): # Change respected config.set_optimizer_jit(True) sess = backend.get_session() - default_config = context.context().config + default_config = get_config() self.assertEqual( sess._config.graph_options.optimizer_options.global_jit_level, default_config.graph_options.optimizer_options.global_jit_level) diff --git a/tensorflow/python/keras/benchmarks/eager_microbenchmarks_test.py b/tensorflow/python/keras/benchmarks/eager_microbenchmarks_test.py index d8e004689d3..e7fcd65a9f6 100644 --- a/tensorflow/python/keras/benchmarks/eager_microbenchmarks_test.py +++ b/tensorflow/python/keras/benchmarks/eager_microbenchmarks_test.py @@ -23,22 +23,22 @@ import six import tensorflow as tf from tensorflow.python.eager import context +from tensorflow.python.eager.context import get_executor from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.platform import benchmark def _run_benchmark(func, num_iters, execution_mode=None): - ctx = context.context() with context.execution_mode(execution_mode): # call func to warm up func() if execution_mode == context.ASYNC: - ctx.executor.wait() + get_executor().wait() start = time.time() for _ in range(num_iters): func() if execution_mode == context.ASYNC: - ctx.executor.wait() + get_executor().wait() end = time.time() return end - start diff --git a/tensorflow/python/keras/benchmarks/model_components_benchmarks_test.py b/tensorflow/python/keras/benchmarks/model_components_benchmarks_test.py index 624c318bedb..c0a8a255f57 100644 --- a/tensorflow/python/keras/benchmarks/model_components_benchmarks_test.py +++ b/tensorflow/python/keras/benchmarks/model_components_benchmarks_test.py @@ -26,6 +26,7 @@ import tensorflow as tf from tensorflow.python.eager import context from tensorflow.python.eager import profiler +from tensorflow.python.eager.context import get_executor from tensorflow.python.platform import test @@ -84,17 +85,16 @@ def make_sequential_keras_model(initializer="ones"): def run_benchmark(func, num_iters, execution_mode=None): - ctx = context.context() with context.execution_mode(execution_mode): # call func to warm up func() if execution_mode == context.ASYNC: - ctx.executor.wait() + get_executor().wait() start = time.time() for _ in xrange(num_iters): func() if execution_mode == context.ASYNC: - ctx.executor.wait() + get_executor().wait() end = time.time() return end - start diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 1fa17d21773..263a341ea61 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -22,6 +22,7 @@ import uuid from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.eager.context import get_device_name from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device @@ -1748,7 +1749,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func, def _get_context_device_type(): """Parse the current context and return the device type, eg CPU/GPU.""" - current_device = context.context().device_name + current_device = get_device_name() if current_device is None: return None return device.DeviceSpec.from_string(current_device).device_type diff --git a/tensorflow/python/keras/tests/automatic_outside_compilation_test.py b/tensorflow/python/keras/tests/automatic_outside_compilation_test.py index cba6adbb35a..7538ab1a8f3 100644 --- a/tensorflow/python/keras/tests/automatic_outside_compilation_test.py +++ b/tensorflow/python/keras/tests/automatic_outside_compilation_test.py @@ -30,9 +30,9 @@ from tensorflow.python.compat import v2_compat from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver -from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import remote +from tensorflow.python.eager.context import set_soft_device_placement from tensorflow.python.framework import ops from tensorflow.python.keras import callbacks from tensorflow.python.keras import initializers @@ -168,7 +168,7 @@ class AutoOutsideCompilationWithKerasTest(test.TestCase): def setUp(self): super(AutoOutsideCompilationWithKerasTest, self).setUp() v2_compat.enable_v2_behavior() - context.context().soft_device_placement = True + set_soft_device_placement(True) self.summary_dir = self.get_temp_dir() def validate_recorded_sumary_file(self, event_files, summary_dict, diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 2e0f1ad8536..6dd3f88694f 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES = [ "__internal__/distribute/__init__.py", "__internal__/distribute/combinations/__init__.py", "__internal__/distribute/multi_process_runner/__init__.py", + "__internal__/eager_context/__init__.py", "__internal__/nest/__init__.py", "__internal__/ops/__init__.py", "__internal__/test/__init__.py", diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.eager_context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.eager_context.pbtxt new file mode 100644 index 00000000000..70c4d74f936 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.eager_context.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.__internal__.eager_context" +tf_module { + member_method { + name: "get_config" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_device_name" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_executor" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_soft_device_placement" + argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt index 6bd4e43b158..4b9c7a0c761 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.pbtxt @@ -16,6 +16,10 @@ tf_module { name: "distribute" mtype: "" } + member { + name: "eager_context" + mtype: "" + } member { name: "nest" mtype: ""