Ensure function_call_options use current config
In addition, we ensure the placer gets the session options. This requires ensuring that the correct defaults are being set. PiperOrigin-RevId: 237881614
This commit is contained in:
parent
f95c2d19ed
commit
df3a9e555f
@ -581,11 +581,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
DumpGraph("Before calling Placer", graph.get());
|
||||
// TODO(b/124993244): Smartly merge options in nested defuns, and raise
|
||||
// exceptions/warnings in case where nested function call options are ignored.
|
||||
// TODO(b/125933502): Currently config proto in function call options is not
|
||||
// respected by placer, because placer and config proto has different default
|
||||
// behaviors (allowing soft placement by default, vs. not allowing it). Pass
|
||||
// config proto with appropriate default values to placer here.
|
||||
Placer placer(graph.get(), &device_set, nullptr, /* No session options */
|
||||
Placer placer(graph.get(), &device_set, &session_options,
|
||||
flr->device() /* Default device */);
|
||||
TF_RETURN_IF_ERROR(placer.Run());
|
||||
|
||||
|
@ -134,7 +134,7 @@ class FunctionCallOptions(object):
|
||||
class _ThreadLocalData(threading.local):
|
||||
"""Thread local storage for the eager context."""
|
||||
|
||||
def __init__(self, config=None):
|
||||
def __init__(self):
|
||||
super(_ThreadLocalData, self).__init__()
|
||||
self.device_spec = _starting_device_spec
|
||||
self.device_name = ""
|
||||
@ -148,29 +148,7 @@ class _ThreadLocalData(threading.local):
|
||||
self._ones_rank_cache = None
|
||||
self._zeros_cache = None
|
||||
self.execution_mode = SYNC
|
||||
|
||||
# Default rewriter config corresponds to turning all default grappler
|
||||
# optimizations on.
|
||||
self._config = config
|
||||
|
||||
self._function_call_options = None
|
||||
|
||||
@property
|
||||
def function_call_options(self):
|
||||
if self._function_call_options is None:
|
||||
base_config = config_pb2.ConfigProto()
|
||||
if self._config is not None:
|
||||
base_config.MergeFrom(self._config)
|
||||
self._config = None
|
||||
self._function_call_options = FunctionCallOptions(
|
||||
config_proto=base_config)
|
||||
|
||||
return self._function_call_options
|
||||
|
||||
@function_call_options.setter
|
||||
def function_call_options(self, function_call_options):
|
||||
self._function_call_options = function_call_options
|
||||
self._config = None
|
||||
self.function_call_options = None
|
||||
|
||||
@property
|
||||
def ones_rank_cache(self):
|
||||
@ -283,14 +261,17 @@ class Context(object):
|
||||
Raises:
|
||||
ValueError: If execution_mode is not valid.
|
||||
"""
|
||||
self._thread_local_data = _ThreadLocalData(config)
|
||||
if config is None:
|
||||
config = config_pb2.ConfigProto(
|
||||
allow_soft_placement=True,
|
||||
log_device_placement=False,
|
||||
)
|
||||
self._config = config
|
||||
self._thread_local_data = _ThreadLocalData()
|
||||
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
|
||||
self._context_handle = None
|
||||
self._context_devices = None
|
||||
self._post_execution_callbacks = []
|
||||
if config is None:
|
||||
config = config_pb2.ConfigProto()
|
||||
self._config = config
|
||||
self._seed = None
|
||||
self._initialize_lock = threading.Lock()
|
||||
if device_policy is None:
|
||||
@ -641,34 +622,26 @@ class Context(object):
|
||||
else:
|
||||
self._execution_mode = mode
|
||||
|
||||
def get_function_call_options(self):
|
||||
@property
|
||||
def function_call_options(self):
|
||||
"""Returns function call options for current thread.
|
||||
|
||||
Note that the returned object is still referenced by the eager context.
|
||||
|
||||
Returns: the FunctionCallOptions for current thread.
|
||||
"""
|
||||
if self._thread_local_data.function_call_options is None:
|
||||
base_config = config_pb2.ConfigProto()
|
||||
base_config.CopyFrom(self._config)
|
||||
self._thread_local_data.function_call_options = FunctionCallOptions(
|
||||
config_proto=base_config)
|
||||
|
||||
return self._thread_local_data.function_call_options
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def function_call_options(self, set_options_func):
|
||||
"""Context manager for setting function call options of current thread.
|
||||
|
||||
Args:
|
||||
set_options_func: A callable that takes one argument of type
|
||||
FunctionCallOptions. It should set the properties of that
|
||||
FunctionCallOptions.
|
||||
|
||||
Yields:
|
||||
Nothing.
|
||||
"""
|
||||
current_options = self.get_function_call_options()
|
||||
old_options = copy.copy(current_options)
|
||||
try:
|
||||
set_options_func(current_options)
|
||||
yield
|
||||
finally:
|
||||
self._thread_local_data.function_call_options = old_options
|
||||
@function_call_options.setter
|
||||
def function_call_options(self, options):
|
||||
"""Returns function call options for current thread."""
|
||||
self._thread_local_data.function_call_options = options
|
||||
|
||||
def async_wait(self):
|
||||
"""Waits for ops dispatched in ASYNC mode to finish."""
|
||||
@ -1083,22 +1056,26 @@ def execution_mode(mode):
|
||||
|
||||
|
||||
@tf_export("experimental.function_executor_type")
|
||||
@tf_contextlib.contextmanager
|
||||
def function_executor_type(executor_type):
|
||||
"""Context manager for setting the executor of eagar defined functions.
|
||||
"""Context manager for setting the executor of eager defined functions.
|
||||
|
||||
Eager defined functions are functions decorated by tf.contrib.eager.defun.
|
||||
|
||||
Args:
|
||||
executor_type: a string for the name of the executor to be used
|
||||
to execute functions defined by tf.contrib.eager.defun.
|
||||
executor_type: a string for the name of the executor to be used to execute
|
||||
functions defined by tf.contrib.eager.defun.
|
||||
|
||||
Returns:
|
||||
Yields:
|
||||
Context manager for setting the executor of eager defined functions.
|
||||
"""
|
||||
def _set_options_func(options):
|
||||
options.executor_type = executor_type
|
||||
|
||||
return context().function_call_options(_set_options_func)
|
||||
current_options = context().function_call_options
|
||||
old_options = copy.copy(current_options)
|
||||
try:
|
||||
current_options.executor_type = executor_type
|
||||
yield
|
||||
finally:
|
||||
context().function_call_options = old_options
|
||||
|
||||
|
||||
def async_wait():
|
||||
@ -1160,25 +1137,6 @@ def export_run_metadata():
|
||||
return context().export_run_metadata()
|
||||
|
||||
|
||||
def function_config_proto(config_proto):
|
||||
"""Context manager for setting the grappler rewrite config.
|
||||
|
||||
This config is used by Grappler when optimizing the function graph.
|
||||
|
||||
Args:
|
||||
config_proto: a `config_pb2.ConfigProto` proto or
|
||||
a serialized string of that proto or None. If None, the default instance
|
||||
of `config_pb2.ConfigProto` will be used.
|
||||
|
||||
Returns:
|
||||
A context manager.
|
||||
"""
|
||||
def _set_options_func(options):
|
||||
options.config_proto_serialized = config_proto
|
||||
|
||||
return context().function_call_options(_set_options_func)
|
||||
|
||||
|
||||
def set_server_def(server_def):
|
||||
context().set_server_def(server_def)
|
||||
|
||||
|
@ -397,7 +397,7 @@ class _EagerDefinedFunction(object):
|
||||
"Arguments and signature arguments do not match: %s %s " %
|
||||
(len(args), len(list(self.signature.input_arg))))
|
||||
|
||||
function_call_options = ctx.get_function_call_options()
|
||||
function_call_options = ctx.function_call_options
|
||||
if function_call_options.config_proto_serialized is None:
|
||||
config = function_utils.get_disabled_rewriter_config()
|
||||
else:
|
||||
|
@ -234,20 +234,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
else:
|
||||
self.assertEqual(got_shape[0], (None,))
|
||||
|
||||
def testWastedAdd(self):
|
||||
|
||||
@def_function.function()
|
||||
def add(x, y):
|
||||
_ = x * y
|
||||
return x + y
|
||||
|
||||
# The default config allows all rewrites.
|
||||
config_proto = config_pb2.ConfigProto()
|
||||
|
||||
with context.function_config_proto(config_proto):
|
||||
t = constant_op.constant(1.0)
|
||||
self.assertAllEqual(add(t, t).numpy(), 2.0)
|
||||
|
||||
def testNoHash(self):
|
||||
|
||||
@def_function.function()
|
||||
|
@ -19,12 +19,14 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -165,9 +167,30 @@ class ConfigTest(test.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
config.set_inter_op_parallelism_threads(1)
|
||||
|
||||
@test_util.run_gpu_only
|
||||
@reset_eager
|
||||
def testEnableSoftPlacement(self):
|
||||
def testSoftPlacement(self):
|
||||
self.assertEqual(config.get_soft_device_placement(), True)
|
||||
|
||||
@def_function.function
|
||||
def mod():
|
||||
with ops.device('/device:GPU:0'):
|
||||
a = constant_op.constant(1.0)
|
||||
b = constant_op.constant(1.0)
|
||||
return math_ops.mod(a, b)
|
||||
|
||||
# Since soft placement is enabled, the mod operation should work with CPU
|
||||
mod()
|
||||
|
||||
config.set_soft_device_placement(False)
|
||||
self.assertEqual(config.get_soft_device_placement(), False)
|
||||
self.assertEqual(
|
||||
config.get_soft_device_placement(),
|
||||
context.context().soft_device_placement)
|
||||
|
||||
# Since soft placement is disabled, the mod operation should fail on GPU
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
mod()
|
||||
|
||||
config.set_soft_device_placement(True)
|
||||
self.assertEqual(config.get_soft_device_placement(), True)
|
||||
@ -175,17 +198,8 @@ class ConfigTest(test.TestCase):
|
||||
config.get_soft_device_placement(),
|
||||
context.context().soft_device_placement)
|
||||
|
||||
config.set_soft_device_placement(False)
|
||||
self.assertEqual(config.get_soft_device_placement(), False)
|
||||
self.assertEqual(
|
||||
config.get_soft_device_placement(),
|
||||
context.context().soft_device_placement)
|
||||
|
||||
constant_op.constant(1)
|
||||
with self.assertRaises(RuntimeError):
|
||||
config.set_soft_device_placement(True)
|
||||
with self.assertRaises(RuntimeError):
|
||||
config.set_soft_device_placement(False)
|
||||
# Since soft placement is re-enabled, the mod operation should work with CPU
|
||||
mod()
|
||||
|
||||
@reset_eager
|
||||
def testLogDevicePlacement(self):
|
||||
|
@ -115,8 +115,8 @@ def maybe_set_lowering_attr(op):
|
||||
op: An `If` or `While` Operation.
|
||||
"""
|
||||
if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
||||
context.context().get_function_call_options().executor_type
|
||||
!= "SINGLE_THREADED_EXECUTOR"):
|
||||
context.context().function_call_options.executor_type !=
|
||||
"SINGLE_THREADED_EXECUTOR"):
|
||||
# pylint: disable=protected-access
|
||||
op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
|
||||
# pylint: enable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user