Export some properties of context.context() as tf.__internal__ APIs.

PiperOrigin-RevId: 341068359
Change-Id: I20e7f3cd6496bbfae66d1f929f60dcf9049a1007
Yanhui Liang 2020-11-06 09:58:57 -08:00 committed by TensorFlower Gardener
parent 71a30c389a
commit e2286aba8e
10 changed files with 81 additions and 13 deletions

@@ -2060,6 +2060,47 @@ def device(name):
   return context().device(name)
 
 
+# Expose some properties of Context as internally public APIs (b/160348781).
+@tf_export("__internal__.eager_context.get_config", v1=[])
+def get_config():
+  """Get the ConfigProto of Context.
+
+  Returns:
+    The ConfigProto of Context.
+  """
+  return context().config
+
+
+@tf_export("__internal__.eager_context.get_device_name", v1=[])
+def get_device_name():
+  """Get the device name for the current thread.
+
+  Returns:
+    The device name for the current thread.
+  """
+  return context().device_name
+
+
+@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
+def set_soft_device_placement(enabled):
+  """Set if soft device placements should be allowed.
+
+  Args:
+    enabled: Whether to enable soft device placement.
+  """
+  context().soft_device_placement = enabled
+
+
+@tf_export("__internal__.eager_context.get_executor", v1=[])
+def get_executor():
+  """Get the Executor of the current thread.
+
+  Returns:
+    The Executor of the current thread.
+  """
+  return context().executor
+
+
 @tf_export("debugging.get_log_device_placement")
 def get_log_device_placement():
   """Get if device placements are logged.

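For reference, a minimal sketch of how the newly exported endpoints can be called once this change is in a build. The `tf.__internal__` namespace is not a stable public API, and the device string in the comment is only illustrative:

import tensorflow as tf

# ConfigProto backing the eager context.
config = tf.__internal__.eager_context.get_config()
print(config.allow_soft_placement)

# Device name for the current thread, e.g. "/job:localhost/replica:0/task:0/device:CPU:0".
print(tf.__internal__.eager_context.get_device_name())

# Let ops fall back to a supported device when the requested one is unavailable.
tf.__internal__.eager_context.set_soft_device_placement(True)

# Block until async eager work queued on this thread's executor has finished.
tf.__internal__.eager_context.get_executor().wait()
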
@@ -41,6 +41,7 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as eager_function
 from tensorflow.python.eager import lift_to_graph
+from tensorflow.python.eager.context import get_config
 from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
@@ -821,7 +822,7 @@ def get_default_session_config():
         'OMP_NUM_THREADS is no longer used by the default Keras config. '
         'To configure the number of threads, use tf.config.threading APIs.')
 
-  config = context.context().config
+  config = get_config()
   config.allow_soft_placement = True
 
   return config

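A quick sketch of the call site changed above, assuming direct use of the private Keras backend module (not a public API):

from tensorflow.python.keras import backend

# Keras builds its default session config from the eager context's ConfigProto,
# then forces soft placement on.
sess_config = backend.get_default_session_config()
assert sess_config.allow_soft_placement
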
@@ -26,6 +26,7 @@ import scipy.sparse
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager.context import get_config
 from tensorflow.python.framework import config
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
@@ -106,7 +107,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
     # User defined jit setting
     config.set_optimizer_jit(False)
     sess = backend.get_session()
-    default_config = context.context().config
+    default_config = get_config()
     self.assertEqual(
         sess._config.graph_options.optimizer_options.global_jit_level,
         default_config.graph_options.optimizer_options.global_jit_level)
@@ -114,7 +115,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
 
     # New session has the same jit setting
     sess = backend.get_session()
-    default_config = context.context().config
+    default_config = get_config()
     self.assertEqual(
         sess._config.graph_options.optimizer_options.global_jit_level,
         default_config.graph_options.optimizer_options.global_jit_level)
@@ -123,7 +124,7 @@ class BackendResetTest(test.TestCase, parameterized.TestCase):
     # Change respected
     config.set_optimizer_jit(True)
     sess = backend.get_session()
-    default_config = context.context().config
+    default_config = get_config()
     self.assertEqual(
         sess._config.graph_options.optimizer_options.global_jit_level,
         default_config.graph_options.optimizer_options.global_jit_level)

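The assertions above compare the session's JIT level with the context's ConfigProto; a standalone sketch of that relationship (the expected enum value is an assumption based on current TensorFlow behavior, not part of this diff):

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2

# Enabling the optimizer JIT should be reflected in the context's ConfigProto.
tf.config.optimizer.set_jit(True)
cfg = tf.__internal__.eager_context.get_config()
print(cfg.graph_options.optimizer_options.global_jit_level ==
      config_pb2.OptimizerOptions.ON_1)  # expected: True
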
@@ -23,22 +23,22 @@ import six
 import tensorflow as tf
 
 from tensorflow.python.eager import context
+from tensorflow.python.eager.context import get_executor
 from tensorflow.python.keras.utils import tf_inspect
 from tensorflow.python.platform import benchmark
 
 
 def _run_benchmark(func, num_iters, execution_mode=None):
-  ctx = context.context()
   with context.execution_mode(execution_mode):
     # call func to warm up
     func()
     if execution_mode == context.ASYNC:
-      ctx.executor.wait()
+      get_executor().wait()
     start = time.time()
     for _ in range(num_iters):
       func()
     if execution_mode == context.ASYNC:
-      ctx.executor.wait()
+      get_executor().wait()
     end = time.time()
 
   return end - start

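The warm-up/wait pattern above, pulled out of the benchmark harness as a self-contained sketch; the helper name and iteration count are illustrative, not part of the diff:

import time

from tensorflow.python.eager import context
from tensorflow.python.eager.context import get_executor


def time_async(func, num_iters=100):
  """Times `func` under async eager execution, draining the executor around the loop."""
  with context.execution_mode(context.ASYNC):
    func()                 # warm up
    get_executor().wait()  # finish warm-up work before starting the clock
    start = time.time()
    for _ in range(num_iters):
      func()
    get_executor().wait()  # include all enqueued work in the measurement
    end = time.time()
  return end - start
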
@@ -26,6 +26,7 @@ import tensorflow as tf
 from tensorflow.python.eager import context
 from tensorflow.python.eager import profiler
+from tensorflow.python.eager.context import get_executor
 from tensorflow.python.platform import test
@@ -84,17 +85,16 @@ def make_sequential_keras_model(initializer="ones"):
 
 
 def run_benchmark(func, num_iters, execution_mode=None):
-  ctx = context.context()
   with context.execution_mode(execution_mode):
     # call func to warm up
     func()
     if execution_mode == context.ASYNC:
-      ctx.executor.wait()
+      get_executor().wait()
     start = time.time()
     for _ in xrange(num_iters):
       func()
     if execution_mode == context.ASYNC:
-      ctx.executor.wait()
+      get_executor().wait()
     end = time.time()
 
   return end - start

@@ -22,6 +22,7 @@ import uuid
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.eager.context import get_device_name
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device
@@ -1748,7 +1749,7 @@ def _generate_defun_backend(unique_api_name, preferred_device, func,
 
 def _get_context_device_type():
   """Parse the current context and return the device type, eg CPU/GPU."""
-  current_device = context.context().device_name
+  current_device = get_device_name()
   if current_device is None:
     return None
   return device.DeviceSpec.from_string(current_device).device_type

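For illustration, what the helper above computes from a fully qualified device name; the string here is an example, not taken from the diff:

from tensorflow.python.framework import device

name = "/job:localhost/replica:0/task:0/device:GPU:0"
spec = device.DeviceSpec.from_string(name)
print(spec.device_type)  # -> "GPU"
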
@@ -30,9 +30,9 @@ from tensorflow.python.compat import v2_compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
 from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
-from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import remote
+from tensorflow.python.eager.context import set_soft_device_placement
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import callbacks
 from tensorflow.python.keras import initializers
@@ -168,7 +168,7 @@ class AutoOutsideCompilationWithKerasTest(test.TestCase):
   def setUp(self):
     super(AutoOutsideCompilationWithKerasTest, self).setUp()
     v2_compat.enable_v2_behavior()
-    context.context().soft_device_placement = True
+    set_soft_device_placement(True)
     self.summary_dir = self.get_temp_dir()
 
   def validate_recorded_sumary_file(self, event_files, summary_dict,

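A brief sketch of why the test enables soft placement, per documented soft-device-placement behavior; the GPU device below is hypothetical:

import tensorflow as tf

tf.__internal__.eager_context.set_soft_device_placement(True)

# With soft placement on, requesting a device that is unavailable (or has no
# kernel for the op) falls back to a supported device instead of raising.
with tf.device("/GPU:0"):
  x = tf.constant([1.0, 2.0]) + 1.0  # runs on CPU if no GPU is present
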
@@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES = [
     "__internal__/distribute/__init__.py",
     "__internal__/distribute/combinations/__init__.py",
     "__internal__/distribute/multi_process_runner/__init__.py",
+    "__internal__/eager_context/__init__.py",
     "__internal__/nest/__init__.py",
     "__internal__/ops/__init__.py",
     "__internal__/test/__init__.py",

@@ -0,0 +1,19 @@
+path: "tensorflow.__internal__.eager_context"
+tf_module {
+  member_method {
+    name: "get_config"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_device_name"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_executor"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_soft_device_placement"
+    argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
+  }
+}

@@ -16,6 +16,10 @@ tf_module {
     name: "distribute"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "eager_context"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "nest"
     mtype: "<type \'module\'>"