Ensure function_call_options use current config

In addition, we ensure the placer receives the session options, which
requires that the correct defaults are set.

PiperOrigin-RevId: 237881614
Gaurav Jain 2019-03-11 14:21:45 -07:00 committed by TensorFlower Gardener
parent f95c2d19ed
commit df3a9e555f
6 changed files with 63 additions and 109 deletions
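
A rough sketch of the intended behavior (not part of the commit; it relies on TF-internal names visible in the diff below, including the private _config attribute):

    from tensorflow.python.eager import context

    ctx = context.context()
    # The context-level ConfigProto now defaults to soft placement enabled
    # (private attribute, shown only to illustrate the new default).
    assert ctx._config.allow_soft_placement
    # Per-thread FunctionCallOptions are now derived lazily from that config.
    opts = ctx.function_call_options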


@@ -581,11 +581,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
   DumpGraph("Before calling Placer", graph.get());
   // TODO(b/124993244): Smartly merge options in nested defuns, and raise
   // exceptions/warnings in case where nested function call options are ignored.
-  // TODO(b/125933502): Currently config proto in function call options is not
-  // respected by placer, because placer and config proto has different default
-  // behaviors (allowing soft placement by default, vs. not allowing it). Pass
-  // config proto with appropriate default values to placer here.
-  Placer placer(graph.get(), &device_set, nullptr, /* No session options */
+  Placer placer(graph.get(), &device_set, &session_options,
                 flr->device() /* Default device */);
   TF_RETURN_IF_ERROR(placer.Run());


@@ -134,7 +134,7 @@ class FunctionCallOptions(object):
 class _ThreadLocalData(threading.local):
   """Thread local storage for the eager context."""

-  def __init__(self, config=None):
+  def __init__(self):
     super(_ThreadLocalData, self).__init__()
     self.device_spec = _starting_device_spec
     self.device_name = ""
@@ -148,29 +148,7 @@ class _ThreadLocalData(threading.local):
     self._ones_rank_cache = None
     self._zeros_cache = None
     self.execution_mode = SYNC
-    # Default rewriter config corresponds to turning all default grappler
-    # optimizations on.
-    self._config = config
-    self._function_call_options = None
-
-  @property
-  def function_call_options(self):
-    if self._function_call_options is None:
-      base_config = config_pb2.ConfigProto()
-      if self._config is not None:
-        base_config.MergeFrom(self._config)
-        self._config = None
-      self._function_call_options = FunctionCallOptions(
-          config_proto=base_config)
-    return self._function_call_options
-
-  @function_call_options.setter
-  def function_call_options(self, function_call_options):
-    self._function_call_options = function_call_options
-    self._config = None
+    self.function_call_options = None

   @property
   def ones_rank_cache(self):
@@ -283,14 +261,17 @@ class Context(object):
     Raises:
      ValueError: If execution_mode is not valid.
     """
-    self._thread_local_data = _ThreadLocalData(config)
+    if config is None:
+      config = config_pb2.ConfigProto(
+          allow_soft_placement=True,
+          log_device_placement=False,
+      )
+    self._config = config
+    self._thread_local_data = _ThreadLocalData()
     self._context_switches = _ContextSwitchStack(self.executing_eagerly())
     self._context_handle = None
     self._context_devices = None
     self._post_execution_callbacks = []
-    if config is None:
-      config = config_pb2.ConfigProto()
-    self._config = config
     self._seed = None
     self._initialize_lock = threading.Lock()
     if device_policy is None:
@@ -641,34 +622,26 @@ class Context(object):
     else:
       self._execution_mode = mode

-  def get_function_call_options(self):
+  @property
+  def function_call_options(self):
     """Returns function call options for current thread.

     Note that the returned object is still referenced by the eager context.

     Returns: the FunctionCallOptions for current thread.
     """
+    if self._thread_local_data.function_call_options is None:
+      base_config = config_pb2.ConfigProto()
+      base_config.CopyFrom(self._config)
+      self._thread_local_data.function_call_options = FunctionCallOptions(
+          config_proto=base_config)
+
     return self._thread_local_data.function_call_options

-  @tf_contextlib.contextmanager
-  def function_call_options(self, set_options_func):
-    """Context manager for setting function call options of current thread.
-
-    Args:
-      set_options_func: A callable that takes one argument of type
-        FunctionCallOptions. It should set the properties of that
-        FunctionCallOptions.
-
-    Yields:
-      Nothing.
-    """
-    current_options = self.get_function_call_options()
-    old_options = copy.copy(current_options)
-    try:
-      set_options_func(current_options)
-      yield
-    finally:
-      self._thread_local_data.function_call_options = old_options
+  @function_call_options.setter
+  def function_call_options(self, options):
+    """Returns function call options for current thread."""
+    self._thread_local_data.function_call_options = options

   def async_wait(self):
     """Waits for ops dispatched in ASYNC mode to finish."""
@@ -1083,22 +1056,26 @@ def execution_mode(mode):
 @tf_export("experimental.function_executor_type")
 @tf_contextlib.contextmanager
 def function_executor_type(executor_type):
-  """Context manager for setting the executor of eagar defined functions.
+  """Context manager for setting the executor of eager defined functions.

   Eager defined functions are functions decorated by tf.contrib.eager.defun.

   Args:
-    executor_type: a string for the name of the executor to be used
-      to execute functions defined by tf.contrib.eager.defun.
+    executor_type: a string for the name of the executor to be used to execute
+      functions defined by tf.contrib.eager.defun.

-  Returns:
+  Yields:
     Context manager for setting the executor of eager defined functions.
   """
-  def _set_options_func(options):
-    options.executor_type = executor_type
-
-  return context().function_call_options(_set_options_func)
+  current_options = context().function_call_options
+  old_options = copy.copy(current_options)
+  try:
+    current_options.executor_type = executor_type
+    yield
+  finally:
+    context().function_call_options = old_options


 def async_wait():
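
With the rewrite above, function_executor_type saves and restores the thread's options itself. A sketch of the intended usage (assuming the TF 1.13-era internal API; f is a placeholder function):

    from tensorflow.python.eager import context
    from tensorflow.python.eager import def_function

    @def_function.function
    def f(x):
      return x + 1

    with context.function_executor_type("SINGLE_THREADED_EXECUTOR"):
      f(1)  # runs with executor_type set on this thread's options
    # On exit, the previous FunctionCallOptions are restored.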
@@ -1160,25 +1137,6 @@ def export_run_metadata():
   return context().export_run_metadata()


-def function_config_proto(config_proto):
-  """Context manager for setting the grappler rewrite config.
-
-  This config is used by Grappler when optimizing the function graph.
-
-  Args:
-    config_proto: a `config_pb2.ConfigProto` proto or
-      a serialized string of that proto or None. If None, the default instance
-      of `config_pb2.ConfigProto` will be used.
-
-  Returns:
-    A context manager.
-  """
-  def _set_options_func(options):
-    options.config_proto_serialized = config_proto
-
-  return context().function_call_options(_set_options_func)
-
-
 def set_server_def(server_def):
   context().set_server_def(server_def)
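
With function_config_proto removed, the equivalent effect would presumably go through the new property instead; a hypothetical sketch, using only names that appear in the diff (FunctionCallOptions and its config_proto argument):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.eager import context

    ctx = context.context()
    custom = config_pb2.ConfigProto(allow_soft_placement=True)
    ctx.function_call_options = context.FunctionCallOptions(config_proto=custom)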


@@ -397,7 +397,7 @@ class _EagerDefinedFunction(object):
           "Arguments and signature arguments do not match: %s %s " %
           (len(args), len(list(self.signature.input_arg))))

-    function_call_options = ctx.get_function_call_options()
+    function_call_options = ctx.function_call_options
     if function_call_options.config_proto_serialized is None:
       config = function_utils.get_disabled_rewriter_config()
     else:


@@ -234,20 +234,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       else:
         self.assertEqual(got_shape[0], (None,))

-  def testWastedAdd(self):
-
-    @def_function.function()
-    def add(x, y):
-      _ = x * y
-      return x + y
-
-    # The default config allows all rewrites.
-    config_proto = config_pb2.ConfigProto()
-    with context.function_config_proto(config_proto):
-      t = constant_op.constant(1.0)
-      self.assertAllEqual(add(t, t).numpy(), 2.0)
-
   def testNoHash(self):

     @def_function.function()


@@ -19,12 +19,14 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -165,9 +167,30 @@ class ConfigTest(test.TestCase):
     with self.assertRaises(RuntimeError):
       config.set_inter_op_parallelism_threads(1)

+  @test_util.run_gpu_only
   @reset_eager
-  def testEnableSoftPlacement(self):
+  def testSoftPlacement(self):
     self.assertEqual(config.get_soft_device_placement(), True)

+    @def_function.function
+    def mod():
+      with ops.device('/device:GPU:0'):
+        a = constant_op.constant(1.0)
+        b = constant_op.constant(1.0)
+        return math_ops.mod(a, b)
+
+    # Since soft placement is enabled, the mod operation should work with CPU
+    mod()
+
+    config.set_soft_device_placement(False)
+    self.assertEqual(config.get_soft_device_placement(), False)
+    self.assertEqual(
+        config.get_soft_device_placement(),
+        context.context().soft_device_placement)
+
+    # Since soft placement is disabled, the mod operation should fail on GPU
+    with self.assertRaises(errors.InvalidArgumentError):
+      mod()
+
+    config.set_soft_device_placement(True)
+    self.assertEqual(config.get_soft_device_placement(), True)
@@ -175,17 +198,8 @@ class ConfigTest(test.TestCase):
         config.get_soft_device_placement(),
         context.context().soft_device_placement)

-    config.set_soft_device_placement(False)
-    self.assertEqual(config.get_soft_device_placement(), False)
-    self.assertEqual(
-        config.get_soft_device_placement(),
-        context.context().soft_device_placement)
-
-    constant_op.constant(1)
-    with self.assertRaises(RuntimeError):
-      config.set_soft_device_placement(True)
-    with self.assertRaises(RuntimeError):
-      config.set_soft_device_placement(False)
+    # Since soft placement is re-enabled, the mod operation should work with CPU
+    mod()

   @reset_eager
   def testLogDevicePlacement(self):
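
The removed assertions capture the old behavior: changing soft placement after the eager context was initialized raised RuntimeError. After this change it can be toggled at runtime. A small sketch, assuming the framework config API exercised by this test:

    from tensorflow.python.framework import config
    from tensorflow.python.framework import constant_op

    constant_op.constant(1)                  # initializes the eager context
    config.set_soft_device_placement(False)  # previously raised RuntimeError here
    config.set_soft_device_placement(True)   # can be flipped back at runtime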


@@ -115,8 +115,8 @@ def maybe_set_lowering_attr(op):
     op: An `If` or `While` Operation.
   """
   if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and
-      context.context().get_function_call_options().executor_type
-      != "SINGLE_THREADED_EXECUTOR"):
+      context.context().function_call_options.executor_type !=
+      "SINGLE_THREADED_EXECUTOR"):
     # pylint: disable=protected-access
     op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access