Ensure function_call_options use current config

In addition, we ensure the placer gets the session options. This
requires ensuring that the correct defaults are being set.

PiperOrigin-RevId: 237881614
This commit is contained in:
Gaurav Jain 2019-03-11 14:21:45 -07:00 committed by TensorFlower Gardener
parent f95c2d19ed
commit df3a9e555f
6 changed files with 63 additions and 109 deletions

View File

@ -581,11 +581,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
DumpGraph("Before calling Placer", graph.get()); DumpGraph("Before calling Placer", graph.get());
// TODO(b/124993244): Smartly merge options in nested defuns, and raise // TODO(b/124993244): Smartly merge options in nested defuns, and raise
// exceptions/warnings in case where nested function call options are ignored. // exceptions/warnings in case where nested function call options are ignored.
// TODO(b/125933502): Currently config proto in function call options is not Placer placer(graph.get(), &device_set, &session_options,
// respected by placer, because placer and config proto has different default
// behaviors (allowing soft placement by default, vs. not allowing it). Pass
// config proto with appropriate default values to placer here.
Placer placer(graph.get(), &device_set, nullptr, /* No session options */
flr->device() /* Default device */); flr->device() /* Default device */);
TF_RETURN_IF_ERROR(placer.Run()); TF_RETURN_IF_ERROR(placer.Run());

View File

@ -134,7 +134,7 @@ class FunctionCallOptions(object):
class _ThreadLocalData(threading.local): class _ThreadLocalData(threading.local):
"""Thread local storage for the eager context.""" """Thread local storage for the eager context."""
def __init__(self, config=None): def __init__(self):
super(_ThreadLocalData, self).__init__() super(_ThreadLocalData, self).__init__()
self.device_spec = _starting_device_spec self.device_spec = _starting_device_spec
self.device_name = "" self.device_name = ""
@ -148,29 +148,7 @@ class _ThreadLocalData(threading.local):
self._ones_rank_cache = None self._ones_rank_cache = None
self._zeros_cache = None self._zeros_cache = None
self.execution_mode = SYNC self.execution_mode = SYNC
self.function_call_options = None
# Default rewriter config corresponds to turning all default grappler
# optimizations on.
self._config = config
self._function_call_options = None
@property
def function_call_options(self):
if self._function_call_options is None:
base_config = config_pb2.ConfigProto()
if self._config is not None:
base_config.MergeFrom(self._config)
self._config = None
self._function_call_options = FunctionCallOptions(
config_proto=base_config)
return self._function_call_options
@function_call_options.setter
def function_call_options(self, function_call_options):
self._function_call_options = function_call_options
self._config = None
@property @property
def ones_rank_cache(self): def ones_rank_cache(self):
@ -283,14 +261,17 @@ class Context(object):
Raises: Raises:
ValueError: If execution_mode is not valid. ValueError: If execution_mode is not valid.
""" """
self._thread_local_data = _ThreadLocalData(config) if config is None:
config = config_pb2.ConfigProto(
allow_soft_placement=True,
log_device_placement=False,
)
self._config = config
self._thread_local_data = _ThreadLocalData()
self._context_switches = _ContextSwitchStack(self.executing_eagerly()) self._context_switches = _ContextSwitchStack(self.executing_eagerly())
self._context_handle = None self._context_handle = None
self._context_devices = None self._context_devices = None
self._post_execution_callbacks = [] self._post_execution_callbacks = []
if config is None:
config = config_pb2.ConfigProto()
self._config = config
self._seed = None self._seed = None
self._initialize_lock = threading.Lock() self._initialize_lock = threading.Lock()
if device_policy is None: if device_policy is None:
@ -641,34 +622,26 @@ class Context(object):
else: else:
self._execution_mode = mode self._execution_mode = mode
def get_function_call_options(self): @property
def function_call_options(self):
"""Returns function call options for current thread. """Returns function call options for current thread.
Note that the returned object is still referenced by the eager context. Note that the returned object is still referenced by the eager context.
Returns: the FunctionCallOptions for current thread. Returns: the FunctionCallOptions for current thread.
""" """
if self._thread_local_data.function_call_options is None:
base_config = config_pb2.ConfigProto()
base_config.CopyFrom(self._config)
self._thread_local_data.function_call_options = FunctionCallOptions(
config_proto=base_config)
return self._thread_local_data.function_call_options return self._thread_local_data.function_call_options
@tf_contextlib.contextmanager @function_call_options.setter
def function_call_options(self, set_options_func): def function_call_options(self, options):
"""Context manager for setting function call options of current thread. """Returns function call options for current thread."""
self._thread_local_data.function_call_options = options
Args:
set_options_func: A callable that takes one argument of type
FunctionCallOptions. It should set the properties of that
FunctionCallOptions.
Yields:
Nothing.
"""
current_options = self.get_function_call_options()
old_options = copy.copy(current_options)
try:
set_options_func(current_options)
yield
finally:
self._thread_local_data.function_call_options = old_options
def async_wait(self): def async_wait(self):
"""Waits for ops dispatched in ASYNC mode to finish.""" """Waits for ops dispatched in ASYNC mode to finish."""
@ -1083,22 +1056,26 @@ def execution_mode(mode):
@tf_export("experimental.function_executor_type") @tf_export("experimental.function_executor_type")
@tf_contextlib.contextmanager
def function_executor_type(executor_type): def function_executor_type(executor_type):
"""Context manager for setting the executor of eagar defined functions. """Context manager for setting the executor of eager defined functions.
Eager defined functions are functions decorated by tf.contrib.eager.defun. Eager defined functions are functions decorated by tf.contrib.eager.defun.
Args: Args:
executor_type: a string for the name of the executor to be used executor_type: a string for the name of the executor to be used to execute
to execute functions defined by tf.contrib.eager.defun. functions defined by tf.contrib.eager.defun.
Returns: Yields:
Context manager for setting the executor of eager defined functions. Context manager for setting the executor of eager defined functions.
""" """
def _set_options_func(options): current_options = context().function_call_options
options.executor_type = executor_type old_options = copy.copy(current_options)
try:
return context().function_call_options(_set_options_func) current_options.executor_type = executor_type
yield
finally:
context().function_call_options = old_options
def async_wait(): def async_wait():
@ -1160,25 +1137,6 @@ def export_run_metadata():
return context().export_run_metadata() return context().export_run_metadata()
def function_config_proto(config_proto):
"""Context manager for setting the grappler rewrite config.
This config is used by Grappler when optimizing the function graph.
Args:
config_proto: a `config_pb2.ConfigProto` proto or
a serialized string of that proto or None. If None, the default instance
of `config_pb2.ConfigProto` will be used.
Returns:
A context manager.
"""
def _set_options_func(options):
options.config_proto_serialized = config_proto
return context().function_call_options(_set_options_func)
def set_server_def(server_def): def set_server_def(server_def):
context().set_server_def(server_def) context().set_server_def(server_def)

View File

@ -397,7 +397,7 @@ class _EagerDefinedFunction(object):
"Arguments and signature arguments do not match: %s %s " % "Arguments and signature arguments do not match: %s %s " %
(len(args), len(list(self.signature.input_arg)))) (len(args), len(list(self.signature.input_arg))))
function_call_options = ctx.get_function_call_options() function_call_options = ctx.function_call_options
if function_call_options.config_proto_serialized is None: if function_call_options.config_proto_serialized is None:
config = function_utils.get_disabled_rewriter_config() config = function_utils.get_disabled_rewriter_config()
else: else:

View File

@ -234,20 +234,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
else: else:
self.assertEqual(got_shape[0], (None,)) self.assertEqual(got_shape[0], (None,))
def testWastedAdd(self):
@def_function.function()
def add(x, y):
_ = x * y
return x + y
# The default config allows all rewrites.
config_proto = config_pb2.ConfigProto()
with context.function_config_proto(config_proto):
t = constant_op.constant(1.0)
self.assertAllEqual(add(t, t).numpy(), 2.0)
def testNoHash(self): def testNoHash(self):
@def_function.function() @def_function.function()

View File

@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -165,9 +167,30 @@ class ConfigTest(test.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
config.set_inter_op_parallelism_threads(1) config.set_inter_op_parallelism_threads(1)
@test_util.run_gpu_only
@reset_eager @reset_eager
def testEnableSoftPlacement(self): def testSoftPlacement(self):
self.assertEqual(config.get_soft_device_placement(), True)
@def_function.function
def mod():
with ops.device('/device:GPU:0'):
a = constant_op.constant(1.0)
b = constant_op.constant(1.0)
return math_ops.mod(a, b)
# Since soft placement is enabled, the mod operation should work with CPU
mod()
config.set_soft_device_placement(False)
self.assertEqual(config.get_soft_device_placement(), False) self.assertEqual(config.get_soft_device_placement(), False)
self.assertEqual(
config.get_soft_device_placement(),
context.context().soft_device_placement)
# Since soft placement is disabled, the mod operation should fail on GPU
with self.assertRaises(errors.InvalidArgumentError):
mod()
config.set_soft_device_placement(True) config.set_soft_device_placement(True)
self.assertEqual(config.get_soft_device_placement(), True) self.assertEqual(config.get_soft_device_placement(), True)
@ -175,17 +198,8 @@ class ConfigTest(test.TestCase):
config.get_soft_device_placement(), config.get_soft_device_placement(),
context.context().soft_device_placement) context.context().soft_device_placement)
config.set_soft_device_placement(False) # Since soft placement is re-enabled, the mod operation should work with CPU
self.assertEqual(config.get_soft_device_placement(), False) mod()
self.assertEqual(
config.get_soft_device_placement(),
context.context().soft_device_placement)
constant_op.constant(1)
with self.assertRaises(RuntimeError):
config.set_soft_device_placement(True)
with self.assertRaises(RuntimeError):
config.set_soft_device_placement(False)
@reset_eager @reset_eager
def testLogDevicePlacement(self): def testLogDevicePlacement(self):

View File

@ -115,8 +115,8 @@ def maybe_set_lowering_attr(op):
op: An `If` or `While` Operation. op: An `If` or `While` Operation.
""" """
if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
context.context().get_function_call_options().executor_type context.context().function_call_options.executor_type !=
!= "SINGLE_THREADED_EXECUTOR"): "SINGLE_THREADED_EXECUTOR"):
# pylint: disable=protected-access # pylint: disable=protected-access
op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True)) op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access # pylint: enable=protected-access