From df3a9e555f666c35e912b0457924ea124da45997 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Mon, 11 Mar 2019 14:21:45 -0700 Subject: [PATCH] Ensure function_call_options use current config In addition, we ensure the placer gets the session options. This requires ensuring that the correct defaults are being set. PiperOrigin-RevId: 237881614 --- .../process_function_library_runtime.cc | 6 +- tensorflow/python/eager/context.py | 108 ++++++------------ tensorflow/python/eager/function.py | 2 +- tensorflow/python/eager/function_test.py | 14 --- tensorflow/python/framework/config_test.py | 38 ++++-- tensorflow/python/ops/control_flow_util_v2.py | 4 +- 6 files changed, 63 insertions(+), 109 deletions(-) diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 37f24f2e431..020bb60b4e6 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -581,11 +581,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( DumpGraph("Before calling Placer", graph.get()); // TODO(b/124993244): Smartly merge options in nested defuns, and raise // exceptions/warnings in case where nested function call options are ignored. - // TODO(b/125933502): Currently config proto in function call options is not - // respected by placer, because placer and config proto has different default - // behaviors (allowing soft placement by default, vs. not allowing it). Pass - // config proto with appropriate default values to placer here. 
- Placer placer(graph.get(), &device_set, nullptr, /* No session options */ + Placer placer(graph.get(), &device_set, &session_options, flr->device() /* Default device */); TF_RETURN_IF_ERROR(placer.Run()); diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 4caf357b4c2..dc068e1a75c 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -134,7 +134,7 @@ class FunctionCallOptions(object): class _ThreadLocalData(threading.local): """Thread local storage for the eager context.""" - def __init__(self, config=None): + def __init__(self): super(_ThreadLocalData, self).__init__() self.device_spec = _starting_device_spec self.device_name = "" @@ -148,29 +148,7 @@ class _ThreadLocalData(threading.local): self._ones_rank_cache = None self._zeros_cache = None self.execution_mode = SYNC - - # Default rewriter config corresponds to turning all default grappler - # optimizations on. - self._config = config - - self._function_call_options = None - - @property - def function_call_options(self): - if self._function_call_options is None: - base_config = config_pb2.ConfigProto() - if self._config is not None: - base_config.MergeFrom(self._config) - self._config = None - self._function_call_options = FunctionCallOptions( - config_proto=base_config) - - return self._function_call_options - - @function_call_options.setter - def function_call_options(self, function_call_options): - self._function_call_options = function_call_options - self._config = None + self.function_call_options = None @property def ones_rank_cache(self): @@ -283,14 +261,17 @@ class Context(object): Raises: ValueError: If execution_mode is not valid. 
""" - self._thread_local_data = _ThreadLocalData(config) + if config is None: + config = config_pb2.ConfigProto( + allow_soft_placement=True, + log_device_placement=False, + ) + self._config = config + self._thread_local_data = _ThreadLocalData() self._context_switches = _ContextSwitchStack(self.executing_eagerly()) self._context_handle = None self._context_devices = None self._post_execution_callbacks = [] - if config is None: - config = config_pb2.ConfigProto() - self._config = config self._seed = None self._initialize_lock = threading.Lock() if device_policy is None: @@ -641,34 +622,26 @@ class Context(object): else: self._execution_mode = mode - def get_function_call_options(self): + @property + def function_call_options(self): """Returns function call options for current thread. Note that the returned object is still referenced by the eager context. Returns: the FunctionCallOptions for current thread. """ + if self._thread_local_data.function_call_options is None: + base_config = config_pb2.ConfigProto() + base_config.CopyFrom(self._config) + self._thread_local_data.function_call_options = FunctionCallOptions( + config_proto=base_config) + return self._thread_local_data.function_call_options - @tf_contextlib.contextmanager - def function_call_options(self, set_options_func): - """Context manager for setting function call options of current thread. - - Args: - set_options_func: A callable that takes one argument of type - FunctionCallOptions. It should set the properties of that - FunctionCallOptions. - - Yields: - Nothing. 
- """ - current_options = self.get_function_call_options() - old_options = copy.copy(current_options) - try: - set_options_func(current_options) - yield - finally: - self._thread_local_data.function_call_options = old_options + @function_call_options.setter + def function_call_options(self, options): + """Sets function call options for current thread.""" + self._thread_local_data.function_call_options = options def async_wait(self): """Waits for ops dispatched in ASYNC mode to finish.""" @@ -1083,22 +1056,26 @@ def execution_mode(mode): @tf_export("experimental.function_executor_type") +@tf_contextlib.contextmanager def function_executor_type(executor_type): - """Context manager for setting the executor of eagar defined functions. + """Context manager for setting the executor of eager defined functions. Eager defined functions are functions decorated by tf.contrib.eager.defun. Args: - executor_type: a string for the name of the executor to be used - to execute functions defined by tf.contrib.eager.defun. + executor_type: a string for the name of the executor to be used to execute + functions defined by tf.contrib.eager.defun. - Returns: + Yields: Context manager for setting the executor of eager defined functions. """ - def _set_options_func(options): - options.executor_type = executor_type - - return context().function_call_options(_set_options_func) + current_options = context().function_call_options + old_options = copy.copy(current_options) + try: + current_options.executor_type = executor_type + yield + finally: + context().function_call_options = old_options def async_wait(): @@ -1160,25 +1137,6 @@ def export_run_metadata(): return context().export_run_metadata() -def function_config_proto(config_proto): - """Context manager for setting the grappler rewrite config. - - This config is used by Grappler when optimizing the function graph. - - Args: - config_proto: a `config_pb2.ConfigProto` proto or - a serialized string of that proto or None. 
If None, the default instance - of `config_pb2.ConfigProto` will be used. - - Returns: - A context manager. - """ - def _set_options_func(options): - options.config_proto_serialized = config_proto - - return context().function_call_options(_set_options_func) - - def set_server_def(server_def): context().set_server_def(server_def) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index d1ed60aa981..00ad03457b5 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -397,7 +397,7 @@ class _EagerDefinedFunction(object): "Arguments and signature arguments do not match: %s %s " % (len(args), len(list(self.signature.input_arg)))) - function_call_options = ctx.get_function_call_options() + function_call_options = ctx.function_call_options if function_call_options.config_proto_serialized is None: config = function_utils.get_disabled_rewriter_config() else: diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index d32a4101465..96184f6656c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -234,20 +234,6 @@ class FunctionTest(test.TestCase, parameterized.TestCase): else: self.assertEqual(got_shape[0], (None,)) - def testWastedAdd(self): - - @def_function.function() - def add(x, y): - _ = x * y - return x + y - - # The default config allows all rewrites. 
- config_proto = config_pb2.ConfigProto() - - with context.function_config_proto(config_proto): - t = constant_op.constant(1.0) - self.assertAllEqual(add(t, t).numpy(), 2.0) - def testNoHash(self): @def_function.function() diff --git a/tensorflow/python/framework/config_test.py b/tensorflow/python/framework/config_test.py index e1efe2ac4a5..e7287c84dbb 100644 --- a/tensorflow/python/framework/config_test.py +++ b/tensorflow/python/framework/config_test.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -165,15 +167,20 @@ class ConfigTest(test.TestCase): with self.assertRaises(RuntimeError): config.set_inter_op_parallelism_threads(1) + @test_util.run_gpu_only @reset_eager - def testEnableSoftPlacement(self): - self.assertEqual(config.get_soft_device_placement(), False) - - config.set_soft_device_placement(True) + def testSoftPlacement(self): self.assertEqual(config.get_soft_device_placement(), True) - self.assertEqual( - config.get_soft_device_placement(), - context.context().soft_device_placement) + + @def_function.function + def mod(): + with ops.device('/device:GPU:0'): + a = constant_op.constant(1.0) + b = constant_op.constant(1.0) + return math_ops.mod(a, b) + + # Since soft placement is enabled, the mod operation should work with CPU + mod() config.set_soft_device_placement(False) self.assertEqual(config.get_soft_device_placement(), False) @@ -181,11 +188,18 @@ class ConfigTest(test.TestCase): config.get_soft_device_placement(), 
context.context().soft_device_placement) - constant_op.constant(1) - with self.assertRaises(RuntimeError): - config.set_soft_device_placement(True) - with self.assertRaises(RuntimeError): - config.set_soft_device_placement(False) + # Since soft placement is disabled, the mod operation should fail on GPU + with self.assertRaises(errors.InvalidArgumentError): + mod() + + config.set_soft_device_placement(True) + self.assertEqual(config.get_soft_device_placement(), True) + self.assertEqual( + config.get_soft_device_placement(), + context.context().soft_device_placement) + + # Since soft placement is re-enabled, the mod operation should work with CPU + mod() @reset_eager def testLogDevicePlacement(self): diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py index 17d1074f0fd..cd37419906b 100644 --- a/tensorflow/python/ops/control_flow_util_v2.py +++ b/tensorflow/python/ops/control_flow_util_v2.py @@ -115,8 +115,8 @@ def maybe_set_lowering_attr(op): op: An `If` or `While` Operation. """ if (not control_flow_util.GraphOrParentsInXlaContext(op.graph) and - context.context().get_function_call_options().executor_type - != "SINGLE_THREADED_EXECUTOR"): + context.context().function_call_options.executor_type != + "SINGLE_THREADED_EXECUTOR"): # pylint: disable=protected-access op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True)) # pylint: enable=protected-access