Export some properties of context.context()
as tf.__internal__ APIs.
PiperOrigin-RevId: 341068359 Change-Id: I20e7f3cd6496bbfae66d1f929f60dcf9049a1007
This commit is contained in:
parent
71a30c389a
commit
e2286aba8e
@ -2060,6 +2060,47 @@ def device(name):
|
||||
return context().device(name)
|
||||
|
||||
|
||||
# Expose some properties of Context as internally public APIs (b/160348781).
|
||||
@tf_export("__internal__.eager_context.get_config", v1=[])
|
||||
def get_config():
|
||||
"""Get the ConfigProto of Context.
|
||||
|
||||
Returns:
|
||||
The ConfigProto of Context.
|
||||
"""
|
||||
return context().config
|
||||
|
||||
|
||||
@tf_export("__internal__.eager_context.get_device_name", v1=[])
|
||||
def get_device_name():
|
||||
"""Get the device name for the current thread.
|
||||
|
||||
Returns:
|
||||
The device name for the current thread.
|
||||
"""
|
||||
return context().device_name
|
||||
|
||||
|
||||
@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
|
||||
def set_soft_device_placement(enabled):
|
||||
"""Set if soft device placements should be allowed.
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable soft device placement.
|
||||
"""
|
||||
context().soft_device_placement = enabled
|
||||
|
||||
|
||||
@tf_export("__internal__.eager_context.get_executor", v1=[])
|
||||
def get_executor():
|
||||
"""Get the Executor of the current thread.
|
||||
|
||||
Returns:
|
||||
The Executor of the current thread.
|
||||
"""
|
||||
return context().executor
|
||||
|
||||
|
||||
@tf_export("debugging.get_log_device_placement")
|
||||
def get_log_device_placement():
|
||||
"""Get if device placements are logged.
|
||||
|
@ -41,6 +41,7 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as eager_function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.eager.context import get_config
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -821,7 +822,7 @@ def get_default_session_config():
|
||||
'OMP_NUM_THREADS is no longer used by the default Keras config. '
|
||||
'To configure the number of threads, use tf.config.threading APIs.')
|
||||
|
||||
config = context.context().config
|
||||
config = get_config()
|
||||
config.allow_soft_placement = True
|
||||
|
||||
return config
|
||||
|
@ -26,6 +26,7 @@ import scipy.sparse
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager.context import get_config
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
@ -106,7 +107,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
|
||||
# User defined jit setting
|
||||
config.set_optimizer_jit(False)
|
||||
sess = backend.get_session()
|
||||
default_config = context.context().config
|
||||
default_config = get_config()
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
@ -114,7 +115,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
# New session has the same jit setting
|
||||
sess = backend.get_session()
|
||||
default_config = context.context().config
|
||||
default_config = get_config()
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
@ -123,7 +124,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
|
||||
# Change respected
|
||||
config.set_optimizer_jit(True)
|
||||
sess = backend.get_session()
|
||||
default_config = context.context().config
|
||||
default_config = get_config()
|
||||
self.assertEqual(
|
||||
sess._config.graph_options.optimizer_options.global_jit_level,
|
||||
default_config.graph_options.optimizer_options.global_jit_level)
|
||||
|
@ -23,22 +23,22 @@ import six
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager.context import get_executor
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.platform import benchmark
|
||||
|
||||
|
||||
def _run_benchmark(func, num_iters, execution_mode=None):
|
||||
ctx = context.context()
|
||||
with context.execution_mode(execution_mode):
|
||||
# call func to warm up
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.executor.wait()
|
||||
get_executor().wait()
|
||||
start = time.time()
|
||||
for _ in range(num_iters):
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.executor.wait()
|
||||
get_executor().wait()
|
||||
end = time.time()
|
||||
|
||||
return end - start
|
||||
|
@ -26,6 +26,7 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import profiler
|
||||
from tensorflow.python.eager.context import get_executor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -84,17 +85,16 @@ def make_sequential_keras_model(initializer="ones"):
|
||||
|
||||
|
||||
def run_benchmark(func, num_iters, execution_mode=None):
|
||||
ctx = context.context()
|
||||
with context.execution_mode(execution_mode):
|
||||
# call func to warm up
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.executor.wait()
|
||||
get_executor().wait()
|
||||
start = time.time()
|
||||
for _ in xrange(num_iters):
|
||||
func()
|
||||
if execution_mode == context.ASYNC:
|
||||
ctx.executor.wait()
|
||||
get_executor().wait()
|
||||
end = time.time()
|
||||
|
||||
return end - start
|
||||
|
@ -22,6 +22,7 @@ import uuid
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager.context import get_device_name
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device
|
||||
@ -1748,7 +1749,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func,
|
||||
|
||||
def _get_context_device_type():
|
||||
"""Parse the current context and return the device type, eg CPU/GPU."""
|
||||
current_device = context.context().device_name
|
||||
current_device = get_device_name()
|
||||
if current_device is None:
|
||||
return None
|
||||
return device.DeviceSpec.from_string(current_device).device_type
|
||||
|
@ -30,9 +30,9 @@ from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager.context import set_soft_device_placement
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import callbacks
|
||||
from tensorflow.python.keras import initializers
|
||||
@ -168,7 +168,7 @@ class AutoOutsideCompilationWithKerasTest(test.TestCase):
|
||||
def setUp(self):
|
||||
super(AutoOutsideCompilationWithKerasTest, self).setUp()
|
||||
v2_compat.enable_v2_behavior()
|
||||
context.context().soft_device_placement = True
|
||||
set_soft_device_placement(True)
|
||||
self.summary_dir = self.get_temp_dir()
|
||||
|
||||
def validate_recorded_sumary_file(self, event_files, summary_dict,
|
||||
|
@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES = [
|
||||
"__internal__/distribute/__init__.py",
|
||||
"__internal__/distribute/combinations/__init__.py",
|
||||
"__internal__/distribute/multi_process_runner/__init__.py",
|
||||
"__internal__/eager_context/__init__.py",
|
||||
"__internal__/nest/__init__.py",
|
||||
"__internal__/ops/__init__.py",
|
||||
"__internal__/test/__init__.py",
|
||||
|
@ -0,0 +1,19 @@
|
||||
path: "tensorflow.__internal__.eager_context"
|
||||
tf_module {
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_device_name"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_executor"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_soft_device_placement"
|
||||
argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -16,6 +16,10 @@ tf_module {
|
||||
name: "distribute"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "eager_context"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "nest"
|
||||
mtype: "<type \'module\'>"
|
||||
|
Loading…
x
Reference in New Issue
Block a user