Expose optimizer options in tf.config

PiperOrigin-RevId: 239084110
Gaurav Jain 2019-03-18 16:22:37 -07:00 committed by TensorFlower Gardener
parent dbfab63703
commit 25c48f50fc
10 changed files with 354 additions and 2 deletions


@@ -1020,8 +1020,10 @@ cuda_py_test(
        ":constant_op",
        ":client_testlib",
        ":platform",
        ":test_ops",
        ":util",
    ],
    tags = ["no_pip"],  # test_ops are not available in pip.
    xla_enable_strict_auto_jit = True,
)


@@ -25,6 +25,7 @@ import random
import threading

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import tf2
from tensorflow.python.framework import c_api_util
@@ -754,6 +755,102 @@ class Context(object):
    self._config.gpu_options.allow_growth = enabled

  @property
  def optimizer_jit(self):
    level = self._config.graph_options.optimizer_options.global_jit_level
    return (level == config_pb2.OptimizerOptions.ON_1 or
            level == config_pb2.OptimizerOptions.ON_2)

  @optimizer_jit.setter
  def optimizer_jit(self, enabled):
    self._config.graph_options.optimizer_options.global_jit_level = (
        config_pb2.OptimizerOptions.ON_1
        if enabled else config_pb2.OptimizerOptions.OFF)

    self._thread_local_data.function_call_options = None

  def get_optimizer_experimental_options(self):
    """Get experimental options for the optimizer.

    Returns:
      Dictionary of current option values
    """
    rewrite_options = self._config.graph_options.rewrite_options
    options = {}

    def rewriter_toggle(option):
      attr = getattr(rewrite_options, option)
      if attr != 0:
        options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)

    def rewriter_bool(option):
      options[option] = getattr(rewrite_options, option)

    rewriter_toggle("layout_optimizer")
    rewriter_toggle("constant_folding")
    rewriter_toggle("shape_optimization")
    rewriter_toggle("remapping")
    rewriter_toggle("arithmetic_optimization")
    rewriter_toggle("dependency_optimization")
    rewriter_toggle("loop_optimization")
    rewriter_toggle("function_optimization")
    rewriter_toggle("debug_stripper")
    rewriter_bool("disable_model_pruning")
    rewriter_toggle("scoped_allocator_optimization")
    rewriter_toggle("pin_to_host_optimization")
    rewriter_toggle("implementation_selector")
    rewriter_bool("disable_meta_optimizer")

    if rewrite_options.min_graph_nodes != 0:
      options["min_graph_nodes"] = rewrite_options.min_graph_nodes

    return options

  def set_optimizer_experimental_options(self, options):
    """Set experimental options for the optimizer.

    Args:
      options: Dictionary of options to modify
    """
    def rewriter_toggle(option):
      toggle = options.get(option, None)
      if toggle is None:
        return

      setattr(self._config.graph_options.rewrite_options,
              option,
              (rewriter_config_pb2.RewriterConfig.ON
               if toggle else rewriter_config_pb2.RewriterConfig.OFF))

    def rewriter_bool(option):
      toggle = options.get(option, None)
      if toggle is None:
        return

      setattr(self._config.graph_options.rewrite_options,
              option,
              toggle)

    rewriter_toggle("layout_optimizer")
    rewriter_toggle("constant_folding")
    rewriter_toggle("shape_optimization")
    rewriter_toggle("remapping")
    rewriter_toggle("arithmetic_optimization")
    rewriter_toggle("dependency_optimization")
    rewriter_toggle("loop_optimization")
    rewriter_toggle("function_optimization")
    rewriter_toggle("debug_stripper")
    rewriter_bool("disable_model_pruning")
    rewriter_toggle("scoped_allocator_optimization")
    rewriter_toggle("pin_to_host_optimization")
    rewriter_toggle("implementation_selector")
    rewriter_bool("disable_meta_optimizer")

    nodes = options.get("min_graph_nodes", None)
    if nodes is not None:
      self._config.graph_options.rewrite_options.min_graph_nodes = nodes

    self._thread_local_data.function_call_options = None

  @property
  def intra_op_parallelism_threads(self):
    return self._config.intra_op_parallelism_threads
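
Read together, the new getter and setter round-trip a plain Python dict through
the RewriterConfig proto: toggle options map to ON/OFF enum values, and unset
toggles are simply omitted from the result. A minimal sketch of that behaviour,
assuming a build that contains this change (the printed dict is illustrative):

  from tensorflow.python.eager import context

  ctx = context.context()
  ctx.set_optimizer_experimental_options({"constant_folding": True,
                                          "min_graph_nodes": -1})
  # Only options that were explicitly set (or previously configured) come
  # back; untouched toggles stay absent rather than defaulting to False.
  print(ctx.get_optimizer_experimental_options())
  # -> {'constant_folding': True, 'min_graph_nodes': -1}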


@@ -122,6 +122,82 @@ def set_inter_op_parallelism_threads(num_threads):
  context.context().inter_op_parallelism_threads = num_threads


@tf_export('config.optimizer.get_jit')
def get_optimizer_jit():
  """Get if JIT compilation is enabled.

  Note that optimizations are only applied in graph mode (i.e. within
  tf.function).

  Returns:
    Whether JIT compilation is enabled.
  """
  return context.context().optimizer_jit


@tf_export('config.optimizer.set_jit')
def set_optimizer_jit(enabled):
  """Set if JIT compilation is enabled.

  Args:
    enabled: Whether to enable JIT compilation.
  """
  context.context().optimizer_jit = enabled


@tf_export('config.optimizer.get_experimental_options')
def get_optimizer_experimental_options():
  """Get experimental optimizer options.

  Refer to tf.config.optimizer.set_experimental_options for a list of current
  options.

  Note that optimizations are only applied in graph mode (i.e. within
  tf.function). In addition, as these are experimental options, the list is
  subject to change.

  Returns:
    Dictionary of configured experimental optimizer options.
  """
  return context.context().get_optimizer_experimental_options()


@tf_export('config.optimizer.set_experimental_options')
def set_optimizer_experimental_options(options):
  """Set experimental optimizer options.

  Note that optimizations are only applied in graph mode (i.e. within
  tf.function). In addition, as these are experimental options, the list is
  subject to change.

  Args:
    options: Dictionary of experimental optimizer options to configure.
      Valid keys:
      - layout_optimizer: Optimize tensor layouts, e.g. try to use the NCHW
        layout on GPU, which is faster.
      - constant_folding: Fold constants. Statically infer the value of
        tensors when possible, and materialize the result using constants.
      - shape_optimization: Simplify computations made on shapes.
      - remapping: Remap subgraphs onto more efficient implementations.
      - arithmetic_optimization: Simplify arithmetic ops with common
        sub-expression elimination and arithmetic simplification.
      - dependency_optimization: Control dependency optimizations. Remove
        redundant control dependencies, which may enable other optimizations.
        This optimizer is also essential for pruning Identity and NoOp nodes.
      - loop_optimization: Loop optimizations.
      - function_optimization: Function optimizations and inlining.
      - debug_stripper: Strips debug-related nodes from the graph.
      - disable_model_pruning: Disable removal of unnecessary ops from the
        graph.
      - scoped_allocator_optimization: Try to allocate some independent Op
        outputs contiguously in order to merge or eliminate downstream Ops.
      - pin_to_host_optimization: Force small ops onto the CPU.
      - implementation_selector: Enable the swap of kernel implementations
        based on the device placement.
      - disable_meta_optimizer: Disable the entire meta optimizer.
      - min_graph_nodes: The minimum number of nodes in a graph for the
        optimizer to run. For smaller graphs, optimization is skipped.
  """
  context.context().set_optimizer_experimental_options(options)


@tf_export('config.get_soft_device_placement')
def get_soft_device_placement():
  """Get if soft device placement is enabled.
@@ -147,7 +223,7 @@ def set_soft_device_placement(enabled):
    3. need to co-locate with reftype input(s) which are from CPU

  Args:
-    enabled: Whether to enabled soft placement.
+    enabled: Whether to enable soft placement.
  """
  context.context().soft_device_placement = enabled
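
The exported tf.config.optimizer surface added above can then be driven
directly from user code; a short usage sketch (the option values are only
examples, and these settings only affect graphs built via tf.function):

  import tensorflow as tf

  tf.config.optimizer.set_jit(True)
  assert tf.config.optimizer.get_jit()

  tf.config.optimizer.set_experimental_options({
      'layout_optimizer': True,   # e.g. prefer NCHW layouts on GPU
      'min_graph_nodes': -1,      # do not skip optimization for tiny graphs
  })
  opts = tf.config.optimizer.get_experimental_options()
  assert opts['layout_optimizer'] is True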


@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
@@ -25,9 +27,11 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


def reset_eager(fn):
@@ -42,7 +46,7 @@ def reset_eager(fn):
  return wrapper


-class ConfigTest(test.TestCase):
+class ConfigTest(test.TestCase, parameterized.TestCase):

  @test_util.run_gpu_only
  @reset_eager
@@ -223,6 +227,131 @@ class ConfigTest(test.TestCase):
    with self.assertRaises(RuntimeError):
      context.set_log_device_placement(False)

  @test_util.run_gpu_only
  @reset_eager
  def testJit(self):
    self.assertEqual(config.get_optimizer_jit(), False)

    # the following function should cause Op fusion to occur. However, there
    # is unfortunately no straightforward way to ensure this. We will just
    # have to settle for creating a test that can trigger JIT.
    @def_function.function
    def fun(a, b):
      c = a * b
      d = c + a
      return d

    a = constant_op.constant([2., 2.])
    b = constant_op.constant([2., 2.])

    self.evaluate(fun(a, b))

    config.set_optimizer_jit(True)
    self.assertEqual(config.get_optimizer_jit(), True)
    self.assertEqual(config.get_optimizer_jit(),
                     context.context().optimizer_jit)

    self.evaluate(fun(a, b))

    config.set_optimizer_jit(False)
    self.assertEqual(config.get_optimizer_jit(), False)
    self.assertEqual(config.get_optimizer_jit(),
                     context.context().optimizer_jit)

    self.evaluate(fun(a, b))

  @parameterized.named_parameters(
      ('LayoutOptimizer', 'layout_optimizer'),
      ('ConstantFolding', 'constant_folding'),
      ('ShapeOptimization', 'shape_optimization'),
      ('Remapping', 'remapping'),
      ('ArithmeticOptimization', 'arithmetic_optimization'),
      ('DependencyOptimization', 'dependency_optimization'),
      ('LoopOptimization', 'loop_optimization'),
      ('FunctionOptimization', 'function_optimization'),
      ('DebugStripper', 'debug_stripper'),
      ('ScopedAllocatorOptimization', 'scoped_allocator_optimization'),
      ('ImplementationSelector', 'implementation_selector'))
  @reset_eager
  def testOptimizerToggleOption(self, field):
    # TODO(b/128531235): Improve testing of option
    options = config.get_optimizer_experimental_options()
    self.assertIsNone(options.get(field))

    config.set_optimizer_experimental_options({field: True})
    options[field] = True
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

    config.set_optimizer_experimental_options({field: False})
    options[field] = False
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

  @parameterized.named_parameters(
      ('DisableModelPruning', 'disable_model_pruning'),
      ('DisableMetaOptimizer', 'disable_meta_optimizer'))
  @reset_eager
  def testOptimizerBoolOption(self, field):
    # TODO(b/128531235): Improve testing of option
    options = config.get_optimizer_experimental_options()
    self.assertFalse(options.get(field))

    config.set_optimizer_experimental_options({field: True})
    options[field] = True
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

    config.set_optimizer_experimental_options({field: False})
    options[field] = False
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

  @test_util.run_gpu_only
  @reset_eager
  def testOptimizerToggleOptionPinToHost(self):
    options = config.get_optimizer_experimental_options()
    self.assertIsNone(options.get('pin_to_host_optimization'))

    @def_function.function
    def fun():
      op = test_ops.device_placement_op()
      return op

    # Force optimizer to run for all graphs
    config.set_optimizer_experimental_options({'min_graph_nodes': -1})
    options['min_graph_nodes'] = -1

    # Since pin to host is disabled, the operation should go on GPU
    gpu = self.evaluate(fun())
    self.assertIn(compat.as_bytes('GPU'), gpu)

    config.set_optimizer_experimental_options(
        {'pin_to_host_optimization': True})
    options['pin_to_host_optimization'] = True
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

    # Since pin to host is enabled, the operation should go on CPU
    cpu = self.evaluate(fun())
    self.assertIn(compat.as_bytes('CPU'), cpu)

    config.set_optimizer_experimental_options(
        {'pin_to_host_optimization': False})
    options['pin_to_host_optimization'] = False
    self.assertDictEqual(config.get_optimizer_experimental_options(), options)
    self.assertDictEqual(
        context.context().get_optimizer_experimental_options(), options)

    # Since pin to host is disabled again, the operation should go on GPU
    gpu2 = self.evaluate(fun())
    self.assertIn(compat.as_bytes('GPU'), gpu2)


if __name__ == '__main__':
  ops.enable_eager_execution()
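
Outside the bazel test harness, the same round-trip can be checked with the
public wrappers alone; a rough sketch assuming a fresh context (test_ops and
its device_placement_op are internal test-only helpers, which is why the BUILD
target above gains the no_pip tag):

  from tensorflow.python.framework import config

  # Toggle options start out unset (absent from the dict), not False.
  options = config.get_optimizer_experimental_options()
  assert options.get('constant_folding') is None

  config.set_optimizer_experimental_options({'constant_folding': True})
  assert config.get_optimizer_experimental_options()['constant_folding'] is True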


@@ -12,6 +12,7 @@ TENSORFLOW_API_INIT_FILES = [
    "config/__init__.py",
    "config/experimental/__init__.py",
    "config/gpu/__init__.py",
    "config/optimizer/__init__.py",
    "config/threading/__init__.py",
    "data/__init__.py",
    "data/experimental/__init__.py",


@@ -13,6 +13,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
    "config/__init__.py",
    "config/experimental/__init__.py",
    "config/gpu/__init__.py",
    "config/optimizer/__init__.py",
    "config/threading/__init__.py",
    "data/__init__.py",
    "data/experimental/__init__.py",


@@ -0,0 +1,19 @@
path: "tensorflow.config.optimizer"
tf_module {
  member_method {
    name: "get_experimental_options"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_jit"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_experimental_options"
    argspec: "args=[\'options\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_jit"
    argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
  }
}


@@ -8,6 +8,10 @@ tf_module {
    name: "gpu"
    mtype: "<type \'module\'>"
  }
  member {
    name: "optimizer"
    mtype: "<type \'module\'>"
  }
  member {
    name: "threading"
    mtype: "<type \'module\'>"


@@ -0,0 +1,19 @@
path: "tensorflow.config.optimizer"
tf_module {
  member_method {
    name: "get_experimental_options"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_jit"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_experimental_options"
    argspec: "args=[\'options\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "set_jit"
    argspec: "args=[\'enabled\'], varargs=None, keywords=None, defaults=None"
  }
}


@@ -8,6 +8,10 @@ tf_module {
    name: "gpu"
    mtype: "<type \'module\'>"
  }
  member {
    name: "optimizer"
    mtype: "<type \'module\'>"
  }
  member {
    name: "threading"
    mtype: "<type \'module\'>"