Ensure function_call_options use current config
In addition, we ensure the placer gets the session options. This requires ensuring that the correct defaults are being set. PiperOrigin-RevId: 237881614
This commit is contained in:
parent
f95c2d19ed
commit
df3a9e555f
@ -581,11 +581,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
|||||||
DumpGraph("Before calling Placer", graph.get());
|
DumpGraph("Before calling Placer", graph.get());
|
||||||
// TODO(b/124993244): Smartly merge options in nested defuns, and raise
|
// TODO(b/124993244): Smartly merge options in nested defuns, and raise
|
||||||
// exceptions/warnings in case where nested function call options are ignored.
|
// exceptions/warnings in case where nested function call options are ignored.
|
||||||
// TODO(b/125933502): Currently config proto in function call options is not
|
Placer placer(graph.get(), &device_set, &session_options,
|
||||||
// respected by placer, because placer and config proto has different default
|
|
||||||
// behaviors (allowing soft placement by default, vs. not allowing it). Pass
|
|
||||||
// config proto with appropriate default values to placer here.
|
|
||||||
Placer placer(graph.get(), &device_set, nullptr, /* No session options */
|
|
||||||
flr->device() /* Default device */);
|
flr->device() /* Default device */);
|
||||||
TF_RETURN_IF_ERROR(placer.Run());
|
TF_RETURN_IF_ERROR(placer.Run());
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ class FunctionCallOptions(object):
|
|||||||
class _ThreadLocalData(threading.local):
|
class _ThreadLocalData(threading.local):
|
||||||
"""Thread local storage for the eager context."""
|
"""Thread local storage for the eager context."""
|
||||||
|
|
||||||
def __init__(self, config=None):
|
def __init__(self):
|
||||||
super(_ThreadLocalData, self).__init__()
|
super(_ThreadLocalData, self).__init__()
|
||||||
self.device_spec = _starting_device_spec
|
self.device_spec = _starting_device_spec
|
||||||
self.device_name = ""
|
self.device_name = ""
|
||||||
@ -148,29 +148,7 @@ class _ThreadLocalData(threading.local):
|
|||||||
self._ones_rank_cache = None
|
self._ones_rank_cache = None
|
||||||
self._zeros_cache = None
|
self._zeros_cache = None
|
||||||
self.execution_mode = SYNC
|
self.execution_mode = SYNC
|
||||||
|
self.function_call_options = None
|
||||||
# Default rewriter config corresponds to turning all default grappler
|
|
||||||
# optimizations on.
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
self._function_call_options = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def function_call_options(self):
|
|
||||||
if self._function_call_options is None:
|
|
||||||
base_config = config_pb2.ConfigProto()
|
|
||||||
if self._config is not None:
|
|
||||||
base_config.MergeFrom(self._config)
|
|
||||||
self._config = None
|
|
||||||
self._function_call_options = FunctionCallOptions(
|
|
||||||
config_proto=base_config)
|
|
||||||
|
|
||||||
return self._function_call_options
|
|
||||||
|
|
||||||
@function_call_options.setter
|
|
||||||
def function_call_options(self, function_call_options):
|
|
||||||
self._function_call_options = function_call_options
|
|
||||||
self._config = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ones_rank_cache(self):
|
def ones_rank_cache(self):
|
||||||
@ -283,14 +261,17 @@ class Context(object):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If execution_mode is not valid.
|
ValueError: If execution_mode is not valid.
|
||||||
"""
|
"""
|
||||||
self._thread_local_data = _ThreadLocalData(config)
|
if config is None:
|
||||||
|
config = config_pb2.ConfigProto(
|
||||||
|
allow_soft_placement=True,
|
||||||
|
log_device_placement=False,
|
||||||
|
)
|
||||||
|
self._config = config
|
||||||
|
self._thread_local_data = _ThreadLocalData()
|
||||||
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
|
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
|
||||||
self._context_handle = None
|
self._context_handle = None
|
||||||
self._context_devices = None
|
self._context_devices = None
|
||||||
self._post_execution_callbacks = []
|
self._post_execution_callbacks = []
|
||||||
if config is None:
|
|
||||||
config = config_pb2.ConfigProto()
|
|
||||||
self._config = config
|
|
||||||
self._seed = None
|
self._seed = None
|
||||||
self._initialize_lock = threading.Lock()
|
self._initialize_lock = threading.Lock()
|
||||||
if device_policy is None:
|
if device_policy is None:
|
||||||
@ -641,34 +622,26 @@ class Context(object):
|
|||||||
else:
|
else:
|
||||||
self._execution_mode = mode
|
self._execution_mode = mode
|
||||||
|
|
||||||
def get_function_call_options(self):
|
@property
|
||||||
|
def function_call_options(self):
|
||||||
"""Returns function call options for current thread.
|
"""Returns function call options for current thread.
|
||||||
|
|
||||||
Note that the returned object is still referenced by the eager context.
|
Note that the returned object is still referenced by the eager context.
|
||||||
|
|
||||||
Returns: the FunctionCallOptions for current thread.
|
Returns: the FunctionCallOptions for current thread.
|
||||||
"""
|
"""
|
||||||
|
if self._thread_local_data.function_call_options is None:
|
||||||
|
base_config = config_pb2.ConfigProto()
|
||||||
|
base_config.CopyFrom(self._config)
|
||||||
|
self._thread_local_data.function_call_options = FunctionCallOptions(
|
||||||
|
config_proto=base_config)
|
||||||
|
|
||||||
return self._thread_local_data.function_call_options
|
return self._thread_local_data.function_call_options
|
||||||
|
|
||||||
@tf_contextlib.contextmanager
|
@function_call_options.setter
|
||||||
def function_call_options(self, set_options_func):
|
def function_call_options(self, options):
|
||||||
"""Context manager for setting function call options of current thread.
|
"""Returns function call options for current thread."""
|
||||||
|
self._thread_local_data.function_call_options = options
|
||||||
Args:
|
|
||||||
set_options_func: A callable that takes one argument of type
|
|
||||||
FunctionCallOptions. It should set the properties of that
|
|
||||||
FunctionCallOptions.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Nothing.
|
|
||||||
"""
|
|
||||||
current_options = self.get_function_call_options()
|
|
||||||
old_options = copy.copy(current_options)
|
|
||||||
try:
|
|
||||||
set_options_func(current_options)
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self._thread_local_data.function_call_options = old_options
|
|
||||||
|
|
||||||
def async_wait(self):
|
def async_wait(self):
|
||||||
"""Waits for ops dispatched in ASYNC mode to finish."""
|
"""Waits for ops dispatched in ASYNC mode to finish."""
|
||||||
@ -1083,22 +1056,26 @@ def execution_mode(mode):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("experimental.function_executor_type")
|
@tf_export("experimental.function_executor_type")
|
||||||
|
@tf_contextlib.contextmanager
|
||||||
def function_executor_type(executor_type):
|
def function_executor_type(executor_type):
|
||||||
"""Context manager for setting the executor of eagar defined functions.
|
"""Context manager for setting the executor of eager defined functions.
|
||||||
|
|
||||||
Eager defined functions are functions decorated by tf.contrib.eager.defun.
|
Eager defined functions are functions decorated by tf.contrib.eager.defun.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
executor_type: a string for the name of the executor to be used
|
executor_type: a string for the name of the executor to be used to execute
|
||||||
to execute functions defined by tf.contrib.eager.defun.
|
functions defined by tf.contrib.eager.defun.
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
Context manager for setting the executor of eager defined functions.
|
Context manager for setting the executor of eager defined functions.
|
||||||
"""
|
"""
|
||||||
def _set_options_func(options):
|
current_options = context().function_call_options
|
||||||
options.executor_type = executor_type
|
old_options = copy.copy(current_options)
|
||||||
|
try:
|
||||||
return context().function_call_options(_set_options_func)
|
current_options.executor_type = executor_type
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
context().function_call_options = old_options
|
||||||
|
|
||||||
|
|
||||||
def async_wait():
|
def async_wait():
|
||||||
@ -1160,25 +1137,6 @@ def export_run_metadata():
|
|||||||
return context().export_run_metadata()
|
return context().export_run_metadata()
|
||||||
|
|
||||||
|
|
||||||
def function_config_proto(config_proto):
|
|
||||||
"""Context manager for setting the grappler rewrite config.
|
|
||||||
|
|
||||||
This config is used by Grappler when optimizing the function graph.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_proto: a `config_pb2.ConfigProto` proto or
|
|
||||||
a serialized string of that proto or None. If None, the default instance
|
|
||||||
of `config_pb2.ConfigProto` will be used.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A context manager.
|
|
||||||
"""
|
|
||||||
def _set_options_func(options):
|
|
||||||
options.config_proto_serialized = config_proto
|
|
||||||
|
|
||||||
return context().function_call_options(_set_options_func)
|
|
||||||
|
|
||||||
|
|
||||||
def set_server_def(server_def):
|
def set_server_def(server_def):
|
||||||
context().set_server_def(server_def)
|
context().set_server_def(server_def)
|
||||||
|
|
||||||
|
@ -397,7 +397,7 @@ class _EagerDefinedFunction(object):
|
|||||||
"Arguments and signature arguments do not match: %s %s " %
|
"Arguments and signature arguments do not match: %s %s " %
|
||||||
(len(args), len(list(self.signature.input_arg))))
|
(len(args), len(list(self.signature.input_arg))))
|
||||||
|
|
||||||
function_call_options = ctx.get_function_call_options()
|
function_call_options = ctx.function_call_options
|
||||||
if function_call_options.config_proto_serialized is None:
|
if function_call_options.config_proto_serialized is None:
|
||||||
config = function_utils.get_disabled_rewriter_config()
|
config = function_utils.get_disabled_rewriter_config()
|
||||||
else:
|
else:
|
||||||
|
@ -234,20 +234,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertEqual(got_shape[0], (None,))
|
self.assertEqual(got_shape[0], (None,))
|
||||||
|
|
||||||
def testWastedAdd(self):
|
|
||||||
|
|
||||||
@def_function.function()
|
|
||||||
def add(x, y):
|
|
||||||
_ = x * y
|
|
||||||
return x + y
|
|
||||||
|
|
||||||
# The default config allows all rewrites.
|
|
||||||
config_proto = config_pb2.ConfigProto()
|
|
||||||
|
|
||||||
with context.function_config_proto(config_proto):
|
|
||||||
t = constant_op.constant(1.0)
|
|
||||||
self.assertAllEqual(add(t, t).numpy(), 2.0)
|
|
||||||
|
|
||||||
def testNoHash(self):
|
def testNoHash(self):
|
||||||
|
|
||||||
@def_function.function()
|
@def_function.function()
|
||||||
|
@ -19,12 +19,14 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -165,15 +167,20 @@ class ConfigTest(test.TestCase):
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
config.set_inter_op_parallelism_threads(1)
|
config.set_inter_op_parallelism_threads(1)
|
||||||
|
|
||||||
|
@test_util.run_gpu_only
|
||||||
@reset_eager
|
@reset_eager
|
||||||
def testEnableSoftPlacement(self):
|
def testSoftPlacement(self):
|
||||||
self.assertEqual(config.get_soft_device_placement(), False)
|
|
||||||
|
|
||||||
config.set_soft_device_placement(True)
|
|
||||||
self.assertEqual(config.get_soft_device_placement(), True)
|
self.assertEqual(config.get_soft_device_placement(), True)
|
||||||
self.assertEqual(
|
|
||||||
config.get_soft_device_placement(),
|
@def_function.function
|
||||||
context.context().soft_device_placement)
|
def mod():
|
||||||
|
with ops.device('/device:GPU:0'):
|
||||||
|
a = constant_op.constant(1.0)
|
||||||
|
b = constant_op.constant(1.0)
|
||||||
|
return math_ops.mod(a, b)
|
||||||
|
|
||||||
|
# Since soft placement is enabled, the mod operation should work with CPU
|
||||||
|
mod()
|
||||||
|
|
||||||
config.set_soft_device_placement(False)
|
config.set_soft_device_placement(False)
|
||||||
self.assertEqual(config.get_soft_device_placement(), False)
|
self.assertEqual(config.get_soft_device_placement(), False)
|
||||||
@ -181,11 +188,18 @@ class ConfigTest(test.TestCase):
|
|||||||
config.get_soft_device_placement(),
|
config.get_soft_device_placement(),
|
||||||
context.context().soft_device_placement)
|
context.context().soft_device_placement)
|
||||||
|
|
||||||
constant_op.constant(1)
|
# Since soft placement is disabled, the mod operation should fail on GPU
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
config.set_soft_device_placement(True)
|
mod()
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
config.set_soft_device_placement(False)
|
config.set_soft_device_placement(True)
|
||||||
|
self.assertEqual(config.get_soft_device_placement(), True)
|
||||||
|
self.assertEqual(
|
||||||
|
config.get_soft_device_placement(),
|
||||||
|
context.context().soft_device_placement)
|
||||||
|
|
||||||
|
# Since soft placement is re-enabled, the mod operation should work with CPU
|
||||||
|
mod()
|
||||||
|
|
||||||
@reset_eager
|
@reset_eager
|
||||||
def testLogDevicePlacement(self):
|
def testLogDevicePlacement(self):
|
||||||
|
@ -115,8 +115,8 @@ def maybe_set_lowering_attr(op):
|
|||||||
op: An `If` or `While` Operation.
|
op: An `If` or `While` Operation.
|
||||||
"""
|
"""
|
||||||
if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
|
||||||
context.context().get_function_call_options().executor_type
|
context.context().function_call_options.executor_type !=
|
||||||
!= "SINGLE_THREADED_EXECUTOR"):
|
"SINGLE_THREADED_EXECUTOR"):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
|
op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user