minor spelling tweaks

commit bd8e308b4c (parent e1c7b48dbd)
@@ -23,7 +23,7 @@ from tensorflow.python.autograph.pyct import templates

 class ConditionalExpressionTransformer(converter.Base):
-  """Converts contitional expressions to functional form."""
+  """Converts conditional expressions to functional form."""

   def visit_IfExp(self, node):
     return templates.replace_as_expression(
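The hunk above fixes the docstring of a transformer that rewrites Python conditional expressions into functional form. As a rough illustration of what "functional form" means here (the `if_exp` helper below is an invented name for illustration, not AutoGraph's actual API):

```python
# Hypothetical sketch of rewriting `x if c else y` into functional form.
# The helper name and its dispatch are assumptions, for illustration only.
def if_exp(cond, if_true, if_false):
    # Both branches become thunks, so their evaluation can be deferred and,
    # for tensor-valued `cond`, routed through an op like tf.cond.
    return if_true() if cond else if_false()

x = -3
abs_x = if_exp(x > 0, lambda: x, lambda: -x)  # replaces `x if x > 0 else -x`
print(abs_x)  # 3
```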
@@ -24,7 +24,7 @@ from tensorflow.python.autograph.core import converter
 from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.autograph.pyct import templates

-# TODO(mdan): Properly extrack boolean ops according to lazy eval rules.
+# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
 # Note that this isn't completely safe either, because tensors may have control
 # dependencies.
 # Note that for loops that should be done after the loop was converted to

@@ -39,7 +39,7 @@ class _RewriteBlock(object):


 class ConditionalReturnRewriter(converter.Base):
-  """Rewrites a a pattern where it's unbovious that all paths return a value.
+  """Rewrites a a pattern where it's unobvious that all paths return a value.

   This rewrite allows avoiding intermediate None return values.

@@ -67,7 +67,7 @@ class SingleReturnTest(converter_testing.TestCase):
     self.assertTransformedEquivalent(test_fn, 2)
     self.assertTransformedEquivalent(test_fn, -2)

-  def test_contitional_missing_else(self):
+  def test_conditional_missing_else(self):

     def test_fn(x):
       if x > 0:
@@ -23,10 +23,10 @@ The class hierarchy is as follows:
     [extends] converter.Base
       [extends] transformer.Base
         [extends] gast.nodeTransformer
-        [uses] transfomer.SourceInfo
+        [uses] transformer.SourceInfo
     [uses] converter.EntityContext
       [uses] converter.ProgramContext
-      [uses] transfomer.SourceInfo
+      [uses] transformer.SourceInfo

 converter.Base is a specialization of transformer.Base for AutoGraph. It's a
 very lightweight subclass that adds a `ctx` attribute holding the corresponding

@@ -34,7 +34,7 @@ class FunctionScope(object):
     * optional TF name scopes - these name scopes match the name of the
       function, for easy visualization in tensorBoard;
     * optional automatic control dependencies - this adds the same mechanism
-      for control dependenecies that is used by `@tf.function`; it can be
+      for control dependencies that is used by `@tf.function`; it can be
       optionally enabled when using `tf.autograph.to_graph`;
     * tracking of autograph conversion state (whether it's enabled by the user,
       conversion options;
@@ -39,7 +39,7 @@ class FunctionWrappersTest(test.TestCase):
       t = constant_op.constant(1)
       self.assertIn('test_name', t.name)

-  def test_auto_cotrol_deps(self):
+  def test_auto_control_deps(self):
     v = variables.Variable(1)
     with function_wrappers.FunctionScope(
         '_', None,

@@ -30,7 +30,7 @@ class _NamingStyle(enum.Enum):


 class Namer(object):
-  """Symbol name generartor."""
+  """Symbol name generator."""

   def __init__(self, global_namespace):
     self.global_namespace = global_namespace

@@ -192,7 +192,7 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):

   # TODO(mdan): Grab features from context.
   # Note: we pass the original context through to convert to properly handle the
-  # following scenario, which can be used insite TF implementations:
+  # following scenario, which can be used inside TF implementations:
   #
   #   ctx = ag_ctx.control_status_ctx()
   #   @function(autograph=False)  # Low-level graph code
@@ -1021,7 +1021,7 @@ class ApiTest(test.TestCase):
                          ag_ctx.Status.ENABLED)
         return 0

-      # Note: the autograph=False sets the contect to Status.DISABLED. The test
+      # Note: the autograph=False sets the connect to Status.DISABLED. The test
       # verifies that to_graph overrides that.
       @def_function.function(autograph=False)
       def f():

@@ -167,7 +167,7 @@ class _UnboundInstanceCache(_FunctionCache):


 # Using a re-entrant lock to guard against the unlikely possibility that the
-# conversion process tiggers additional code execution.
+# conversion process triggers additional code execution.
 _CACHE_LOCK = threading.RLock()

@@ -253,7 +253,7 @@ def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,

 def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
   """Returns a (possibly cached) factory for the converted result of entity."""
-  # The cache subkey encompases any conversion options on which the generated
+  # The cache subkey encompasses any conversion options on which the generated
   # code may depend.
   # The cached factory includes the necessary definitions to distinguish
   # between the global and non-global free variables. For this reason, the
@@ -773,7 +773,7 @@ class _PythonLoopChecker(object):
     self.check_op_count_after_iteration = False
     self.ops_before_iteration = None

-  def _verify_ineffcient_unroll(self):
+  def _verify_inefficient_unroll(self):
     """Checks for possibly-inefficient creation of ops in a Python loop."""
     assert self.ops_before_iteration is not None
     ops_after_iteration = self._get_ops()

@@ -810,7 +810,7 @@ class _PythonLoopChecker(object):
       self._check_unroll_limits()

     if self.check_op_count_after_iteration:
-      did_warn = self._verify_ineffcient_unroll()
+      did_warn = self._verify_inefficient_unroll()
       if did_warn:
         self._stop_checking_inefficient_unroll()  # Only warn once.
       elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
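Both hunks above touch `_PythonLoopChecker`, which warns when a Python loop keeps creating graph ops on every iteration, a sign the loop is being unrolled inefficiently instead of converted. A minimal sketch of that op-counting strategy, with all names assumed for illustration:

```python
# Illustrative op-count-based unroll check (not the real TensorFlow class).
INEFFICIENT_UNROLL_MIN_OPS = 1000  # assumed threshold of ops per iteration

class LoopChecker:
    def __init__(self, get_op_count):
        self.get_op_count = get_op_count  # e.g. counts nodes in the current graph
        self.ops_before_iteration = None

    def before_iteration(self):
        self.ops_before_iteration = self.get_op_count()

    def after_iteration(self):
        created = self.get_op_count() - self.ops_before_iteration
        if created > INEFFICIENT_UNROLL_MIN_OPS:
            print("Warning: large number of ops created in a Python loop; "
                  "the loop is probably being inefficiently unrolled.")
            return True  # caller stops checking, so the warning fires once
        return False
```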
@@ -54,7 +54,7 @@ UNSPECIFIED = object()

 def overload_of(f):
   if f in SUPPORTED_BUILTINS:
-    return BUILTIN_FUINCTIONS_MAP[f.__name__]
+    return BUILTIN_FUNCTIONS_MAP[f.__name__]
   return f

@@ -441,7 +441,7 @@ def all_(iterable):
     return _py_all(iterable)


-# all() operation is similiar to any() and could be translated
+# all() operation is similar to any() and could be translated
 # to `filter(False)` then `take(1)`, and check if `False` exists.
 def _tf_dataset_all(iterable):
   # check and make sure iterable.element_spec only consists of one

@@ -467,7 +467,7 @@ SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map,
 if six.PY2:
   SUPPORTED_BUILTINS += (xrange,)

-BUILTIN_FUINCTIONS_MAP = {
+BUILTIN_FUNCTIONS_MAP = {
     'abs': abs_,
     'float': float_,
     'int': int_,
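The two `BUILTIN_FUNCTIONS_MAP` fixes above belong to the machinery that swaps Python builtins for AutoGraph-aware overloads by name. A self-contained sketch of that dispatch (the overload bodies are stand-ins, not TensorFlow's implementations):

```python
# Stand-in overloads; the real ones special-case tensor inputs (e.g. tf.abs).
def abs_(x):
    return abs(x)

def len_(x):
    return len(x)

SUPPORTED_BUILTINS = (abs, len)
BUILTIN_FUNCTIONS_MAP = {'abs': abs_, 'len': len_}

def overload_of(f):
    # Swap a supported builtin for its overload; leave anything else alone.
    if f in SUPPORTED_BUILTINS:
        return BUILTIN_FUNCTIONS_MAP[f.__name__]
    return f

print(overload_of(abs)(-2))  # 2, dispatched through abs_
print(overload_of(max))      # <built-in function max>, returned unchanged
```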
@@ -248,7 +248,7 @@ class GraphBuilder(object):
   This builder ignores the flow generated by exceptions, which are assumed to
   always be catastrophic and present purely for diagnostic purposes (e.g. to
   print debug information). Statements like raise and try/catch sections are
-  allowed and will generate control flow edges, but ordinaty statements are
+  allowed and will generate control flow edges, but ordinary statements are
   assumed not to raise exceptions.

   Finally sections are also correctly interleaved between break/continue/return

@@ -47,7 +47,7 @@ class DummyGensym(object):

 # These two test functions have to be top-level, not nested, for compatibility
 # with some unknown version of Python 2.7 preceding 2.7.15. Why? Because
-# `exec` and nested function definitions _incomaptibly_ change the
+# `exec` and nested function definitions _incompatibly_ change the
 # representation of local variables, such that `exec` inside a nested function
 # definition is a syntax error in that version. The tuple form of `exec` fixes
 # this problem, but apparently that was introduced in some unknown version of

@@ -465,7 +465,7 @@ class AnfNonTransformationTest(AnfTransformerTest):
     node, _ = parser.parse_entity(test_fn, future_features=())
     orig_source = parser.unparse(node, indentation=' ')
     orig_str = textwrap.dedent(orig_source).strip()
-    config = [(anf.ANY, anf.LEAVE)]  # Configuration to trasform nothing
+    config = [(anf.ANY, anf.LEAVE)]  # Configuration to transform nothing
     node = anf.transform(
         node, self._simple_context(),
         config=config, gensym_source=DummyGensym)
@@ -252,7 +252,7 @@ def resolve(node, source, context_filepath, context_lineno, context_col_offset):


 def resolve_entity(node, source, entity):
-  """Like resolve, but extracts the context informartion from an entity."""
+  """Like resolve, but extracts the context information from an entity."""
   lines, lineno = tf_inspect.getsourcelines(entity)
   filepath = tf_inspect.getsourcefile(entity)

@@ -91,7 +91,7 @@ def dedent_block(code_string):
       # See:
      # https://docs.python.org/3/reference/lexical_analysis.html#indentation
       raise errors.UnsupportedLanguageElementError(
-          'code mixing tabs and spaces for intentation is not allowed')
+          'code mixing tabs and spaces for indentation is not allowed')
     if len(tok_string) >= block_level:
       tok_string = tok_string[block_level:]
     tokens[i] = (tok_type, tok_string)

@@ -44,7 +44,7 @@ class Scope(object):
   Scope objects are mutable during construction only, and must be frozen using
   `Scope.finalize()` before use. Furthermore, a scope is consistent only after
-  all its chiledren have been frozen. While analysing code blocks, scopes are
+  all its children have been frozen. While analysing code blocks, scopes are
   being gradually built, from the innermost scope outward. Freezing indicates
   that the analysis of a code block is complete. Once frozen, mutation is no
   longer allowed. `is_final` tracks whether the scope is frozen or not. Certain
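The `Scope` docstring above describes a build-then-freeze lifecycle: scopes are mutable while being constructed innermost-first, and a parent is consistent only once all of its children are frozen. A small sketch of that pattern (class shape and field names are assumed, not the actual implementation):

```python
# Illustrative freeze-after-construction pattern.
class Scope:
    def __init__(self, parent=None):
        self.parent = parent
        self.children = []
        self.symbols = set()
        self.is_final = False
        if parent is not None:
            parent.children.append(self)

    def add_symbol(self, name):
        assert not self.is_final, "frozen scopes are immutable"
        self.symbols.add(name)

    def finalize(self):
        # Consistent only after all children are frozen (innermost-first).
        assert all(child.is_final for child in self.children)
        self.is_final = True

outer = Scope()
inner = Scope(parent=outer)
inner.add_symbol("x")
inner.finalize()
outer.finalize()  # legal: every child is already frozen
```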
@@ -234,7 +234,7 @@ class TreeAnnotator(transformer.Base):
     # Recursively process any remaining subfunctions.
     self.current_analyzer = analyzer
     # Note: not visiting name, decorator_list and returns because they don't
-    # apply to this anlysis.
+    # apply to this analysis.
     # TODO(mdan): Should we still process the function name?
     node.args = self.visit(node.args)
     node.body = self.visit_block(node.body)

@@ -253,7 +253,7 @@ class Base(gast.NodeTransformer):
     self.enter_local_scope()

     # Allows scoping of local variables to keep state across calls to visit_*
-    # methods. Multiple scope hierchies may exist and are keyed by tag. A scope
+    # methods. Multiple scope hierarchies may exist and are keyed by tag. A scope
     # is valid at one or more nodes and all its children. Scopes created in
     # child nodes supersede their parent. Scopes are isolated from one another.
     self.state = _State()

@@ -97,7 +97,7 @@ class VirtualGpuTestUtil(object):

   # Generates a list of 3-tuples, each tuple contains the source and destination
   # device index for a binary operation like 'add', like:
-  #   (src_devcie_1, src_device_2, dst_device)
+  #   (src_device_1, src_device_2, dst_device)
   def _GenerateOperationPlacement(self):
     result = []
     for unused_i in range(self._num_ops):
@@ -96,7 +96,7 @@ def forward_compatible(year, month, day):
     if compat.forward_compatible(year, month, day):
       # Can use the awesome new implementation.
       return gen_math_ops.my_new_awesome_add(inputs, name)
-    # To maintain forward compatibiltiy, use the old implementation.
+    # To maintain forward compatibility, use the old implementation.
     return gen_math_ops.add(inputs, name)
  ```
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for forward and backwards compatibility utilties."""
+"""Tests for forward and backwards compatibility utilities."""

 from __future__ import absolute_import
 from __future__ import division

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for forward and backwards compatibility utilties."""
+"""Tests for forward and backwards compatibility utilities."""

 from __future__ import absolute_import
 from __future__ import division

@@ -114,7 +114,7 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
         run_params=run_params,
         conversion_params=conversion_params,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimentional input.
+        # format to NCHW format under four dimensional input.
         disable_non_trt_optimizers=True)
     return conversion_params._replace(
         rewriter_config_template=rewrite_config_with_trt)
@@ -89,7 +89,7 @@ class DynamicInputShapesTest(trt_test.TfTrtIntegrationTestBase):
         run_params=run_params,
         conversion_params=conversion_params,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimentional input.
+        # format to NCHW format under four dimensional input.
         disable_non_trt_optimizers=True)
     return conversion_params._replace(
         rewriter_config_template=rewrite_config_with_trt)

@@ -14,7 +14,7 @@
 # ==============================================================================
 """This test checks a situation where the same tensor is considered as an output

-multiple times because it has been duplicated by 2+ indentity ops. Previously,
+multiple times because it has been duplicated by 2+ identity ops. Previously,
 the tensor would be renamed multiple times, overwriting the output binding name
 which resulted in a runtime error when the binding would not be found.
 """

@@ -55,7 +55,7 @@ class ExcludeUnsupportedInt32Test(trt_test.TfTrtIntegrationTestBase):
         run_params=run_params,
         conversion_params=conversion_params,
         # Disable layout optimizer, since it will convert BiasAdd with NHWC
-        # format to NCHW format under four dimentional input.
+        # format to NCHW format under four dimensional input.
         disable_non_trt_optimizers=True)
     return conversion_params._replace(
         rewriter_config_template=rewrite_config_with_trt)

@@ -145,7 +145,7 @@ class TrtConversionParams(collections.namedtuple("TrtConversionParams", [
       missing ranges. The calibration graph must be converted to an inference
       graph by running calibration with calibrate(). If set to False,
       quantization nodes will be expected for every tensor in the graph
-      (exlcuding those which will be fused). If a range is missing, an error
+      (excluding those which will be fused). If a range is missing, an error
       will occur. Please note that accuracy may be negatively affected if
       there is a mismatch between which tensors TRT quantizes and which
       tensors were trained with fake quantization.
@@ -207,7 +207,7 @@ def _check_conversion_params(conversion_params, is_v2=False):
             "Found more than one TensorRTOptimizer in "
             "rewriter_config_template while only one is allowed.")
         trt_optimizer = optimizer
-    # If rewriter_config_template is set, it should inculde TensorRTOptimizer.
+    # If rewriter_config_template is set, it should include TensorRTOptimizer.
     # It is possible to remove this requirement if needed.
     if not trt_optimizer:
       raise ValueError(

@@ -327,7 +327,7 @@ def get_tensorrt_rewriter_config(conversion_params,
     rewriter_config_with_trt.CopyFrom(
         conversion_params.rewriter_config_template)

-  # Disabling optimizers should happen after CopyFrom the temaplte
+  # Disabling optimizers should happen after CopyFrom the template
   # otherwise the template can overwrite the disablement.
   if disable_non_trt_optimizers:
     off = rewriter_config_pb2.RewriterConfig.OFF

@@ -443,7 +443,7 @@ class TrtGraphConverter(object):
         missing ranges. The calibration graph must be converted to an inference
         graph by running calibration with calibrate(). If set to False,
         quantization nodes will be expected for every tensor in the graph
-        (exlcuding those which will be fused). If a range is missing, an error
+        (excluding those which will be fused). If a range is missing, an error
         will occur. Please note that accuracy may be negatively affected if
         there is a mismatch between which tensors TRT quantizes and which
         tensors were trained with fake quantization.
@@ -500,7 +500,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     # Load and verify the converted model.
     #
     # TODO(laigd): the name of the new input_signature of the
-    # `root_with_trt.run` function is empty string (originaly was None),
+    # `root_with_trt.run` function is empty string (originally was None),
     # investigate why.
     root_with_trt = load.load(output_saved_model_dir)
     # TODO(laigd): `root_with_trt.run` is still using the original graph without

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Bechmarks for `tf.data.Dataset.map()`."""
+"""Benchmarks for `tf.data.Dataset.map()`."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -440,7 +440,7 @@ def make_csv_dataset_v2(
   if compression_type is not None:
     compression_type_value = tensor_util.constant_value(compression_type)
     if compression_type_value is None:
-      raise ValueError("Received unkown compression_type")
+      raise ValueError("Received unknown compression_type")
     if compression_type_value == "GZIP":
       file_io_fn = lambda filename: gzip.open(filename, "rt")
     elif compression_type_value == "ZLIB":
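The `make_csv_dataset_v2` hunk above picks a file-opening function from the detected compression type. A condensed sketch of that dispatch (the gzip branch mirrors the code shown; the uncompressed default is an assumption, since the remaining branches are elided from the hunk):

```python
import gzip

def make_file_io_fn(compression_type_value):
    # Unresolvable compression types are rejected, as in the hunk above.
    if compression_type_value is None:
        raise ValueError("Received unknown compression_type")
    if compression_type_value == "GZIP":
        return lambda filename: gzip.open(filename, "rt")
    if compression_type_value == "":
        return lambda filename: open(filename, "r")  # assumed: no compression
    raise ValueError("Unsupported compression_type: %s" % compression_type_value)
```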
@@ -122,7 +122,7 @@ class FromGeneratorTest(test_base.DatasetTestBase, parameterized.TestCase):
     # The interleave transformation is essentially a flat map that
     # draws from multiple input datasets concurrently (in a cyclic
-    # fashion). By placing `Datsaet.from_generator()` inside an
+    # fashion). By placing `Dataset.from_generator()` inside an
     # interleave, we test its behavior when multiple iterators are
     # active at the same time; by additionally prefetching inside the
     # interleave, we create the possibility of parallel (modulo GIL)

@@ -2526,7 +2526,7 @@ def get_structure(dataset_or_iterator):
   Returns:
     A nested structure of `tf.TypeSpec` objects matching the structure of an
-    element of `dataset_or_iterator` and spacifying the type of individal
+    element of `dataset_or_iterator` and specifying the type of individual
     components.

   Raises:

@@ -3013,7 +3013,7 @@ class StructuredFunctionWrapper(object):
       use_legacy_function: (Optional.) A boolean that determines whether the
         function be created using `tensorflow.python.eager.function.defun`
         (default behavior) or `tensorflow.python.framework.function.Defun`
-        (legacy beheavior).
+        (legacy behavior).
       defun_kwargs: (Optional.) A dictionary mapping string argument names to
         values. If supplied, will be passed to `function` as keyword arguments.
@@ -3049,7 +3049,7 @@ class StructuredFunctionWrapper(object):
     # There is no graph to add in eager mode.
     add_to_graph &= not context.executing_eagerly()
     # There are some lifetime issues when a legacy function is not added to a
-    # out-living graph. It's already deprecated so de-priotizing the fix.
+    # out-living graph. It's already deprecated so de-prioritizing the fix.
     add_to_graph |= use_legacy_function

     if defun_kwargs is None:

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Python dataset sparse tensor utility functitons."""
+"""Python dataset sparse tensor utility functions."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -398,7 +398,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
     # Define three mutually incompatible values/structures, and assert that:
     # 1. Using one structure to flatten a value with an incompatible structure
     #    fails.
-    # 2. Using one structure to restructre a flattened value with an
+    # 2. Using one structure to restructure a flattened value with an
     #    incompatible structure fails.
     value_tensor = constant_op.constant(42.0)
     s_tensor = structure.type_spec_from_value(value_tensor)
@@ -1214,7 +1214,7 @@ class CursesUI(base_ui.BaseUI):
     """Pad the whitespace at the end of a line with the default color pair.

     Prevents spurious color pairs from appearing at the end of the lines in
-    certain text terimnals.
+    certain text terminals.

     Args:
       pad: The curses pad object to operate on.

@@ -115,7 +115,7 @@ def rich_text_lines_from_rich_line_list(rich_text_list, annotations=None):
   Args:
     rich_text_list: a list of RichLine objects or strings
-    annotations: annotatoins for the resultant RichTextLines object.
+    annotations: annotations for the resultant RichTextLines object.

   Returns:
     A corresponding RichTextLines object.

@@ -476,7 +476,7 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
     help_lines = registry.get_help().lines

     # The help info should list commands in alphabetically sorted order,
-    # regardless of order in which the commands are reigstered.
+    # regardless of order in which the commands are registered.
     self.assertEqual("cols", help_lines[0])
     self.assertTrue(help_lines[1].endswith("Aliases: c"))
     self.assertFalse(help_lines[2])
@@ -790,7 +790,7 @@ class SliceRichTextLinesTest(test_util.TensorFlowTestCase):
     self.assertEqual(["Roses are red"], sliced.lines)
     self.assertEqual({0: [(0, 5, "red")]}, sliced.font_attr_segs)

-    # Non-line-number metadata should be preseved.
+    # Non-line-number metadata should be preserved.
     self.assertEqual({
         0: "longer wavelength",
         "foo_metadata": "bar"

@@ -1024,7 +1024,7 @@ class CommandHistoryTest(test_util.TensorFlowTestCase):
     self.assertEqual(
         ["help 2\n", "help 3\n", "help 4\n"], f.readlines())

-  def testCommandHistoryHandlesReadingIOErrorGracoiusly(self):
+  def testCommandHistoryHandlesReadingIOErrorGraciously(self):
     with open(self._history_file_path, "wt") as f:
       f.write("help\n")

@@ -1037,7 +1037,7 @@ class CommandHistoryTest(test_util.TensorFlowTestCase):

     self._restoreFileReadWritePermissions(self._history_file_path)

-  def testCommandHistoryHandlesWritingIOErrorGracoiusly(self):
+  def testCommandHistoryHandlesWritingIOErrorGraciously(self):
     with open(self._history_file_path, "wt") as f:
       f.write("help\n")
@@ -55,7 +55,7 @@ def _parse_debug_tensor_name(debug_tensor_name):
       `None`.
     node_name: Name of the node.
     output_slot: Output slot index as an `int`.
-    debug_op: If the debug op suffix exists, the debug op name; otheriwse,
+    debug_op: If the debug op suffix exists, the debug op name; otherwise,
       `None`.
     exec_index: Execution index (applicable to cases in which a debug tensor
       is computed multiple times in a `tf.Session.run` call, e.g., due to

@@ -273,11 +273,11 @@ class ProfileAnalyzerListProfileTest(test_util.TensorFlowTestCase):
     prof_output = prof_analyzer.list_profile(["-f", ".*file2"]).lines
     _assert_at_least_one_line_matches(r"Add/123", prof_output)
     _assert_no_lines_match(r"Mul/456", prof_output)
-    # Fitler by execution time.
+    # Filter by execution time.
     prof_output = prof_analyzer.list_profile(["-e", "[5, 10]"]).lines
     _assert_at_least_one_line_matches(r"Mul/456", prof_output)
     _assert_no_lines_match(r"Add/123", prof_output)
-    # Fitler by op time.
+    # Filter by op time.
     prof_output = prof_analyzer.list_profile(["-o", ">=2"]).lines
     _assert_at_least_one_line_matches(r"Add/123", prof_output)
     _assert_no_lines_match(r"Mul/456", prof_output)
@@ -78,7 +78,7 @@ class RichTextLinesTest(test_util.TensorFlowTestCase):
   def _checkTensorElementLocations(self, out, a):
     """Check the results of locate_tensor_element on an ndarray representation.

-    that represents a numpy.ndaray.
+    that represents a numpy.ndarray.

     Args:
       out: An instance of RichTextLines representing a numpy.ndarray.

@@ -426,7 +426,7 @@ def disable_check_numerics():
   """Disable the eager/graph unified numerics checking mechanism.

   This method can be used after a call to `tf.debugging.enable_check_numerics()`
-  to disable the numerics-checking mechanism that catches inifnity and NaN
+  to disable the numerics-checking mechanism that catches infinity and NaN
   values output by ops executed eagerly or in tf.function-compiled graphs.

   This method is idempotent. Calling it multiple times has the same effect
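For context, the docstring fixed above pairs with `tf.debugging.enable_check_numerics()`; a minimal usage sketch of the enable/disable pair:

```python
import tensorflow as tf

tf.debugging.enable_check_numerics()
# ... run eager ops or tf.functions; any op producing inf/NaN now raises ...
tf.debugging.disable_check_numerics()  # idempotent: calling it twice is safe
```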
@@ -430,7 +430,7 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
     self.assertIn("one_over_x = 1.0 / x", message)

   @test_util.run_in_graph_and_eager_modes
-  def testInfInCustomKerasLayerWithoutTfFuntionPredictCall(self):
+  def testInfInCustomKerasLayerWithoutTfFunctionPredictCall(self):
     """Test catching Infinity in a custom layer, w/o tf.function."""
     check_numerics_callback.enable_check_numerics()
@@ -483,7 +483,7 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
     check_numerics_callback.enable_check_numerics()

     def generate_nan(x):
-      """Intetionally generates NaNs by taking log of negative number."""
+      """Intentionally generates NaNs by taking log of negative number."""
       casted_x = math_ops.cast(x, dtypes.float32)
       return math_ops.log([[-1.0, 1.0], [3.0, 5.0]]) + casted_x

@@ -503,7 +503,7 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
         message)

   @test_util.run_in_graph_and_eager_modes
-  def testCustomGradietWithNaNWithTfFunction(self):
+  def testCustomGradientWithNaNWithTfFunction(self):
     """Test that callback catches NaN in a gradient function during backprop."""
     check_numerics_callback.enable_check_numerics()

@@ -296,7 +296,7 @@ class DebugTensorDatum(object):
       directory is `/tmp/tfdbg_1` and the dump file is at
       `/tmp/tfdbg_1/<device_path>/>ns_1/node_a_0_DebugIdentity_123456789`,
       then the value of the debug_dump_rel_path should be
-      `<device_path>/ns_1/node_a_0_DebugIdenity_1234456789`.
+      `<device_path>/ns_1/node_a_0_DebugIdentity_1234456789`.

     Raises:
       ValueError: If the base file name of the dump file does not conform to
@@ -116,7 +116,7 @@ class GradientsDebugger(object):
     The side effect of this method is that when gradient tensor(s) are created
     with respect to the any paths that include the `input_tensor`, the gradient
-    tensor(s) with repsect to `input_tensor` will be registered with this
+    tensor(s) with respect to `input_tensor` will be registered with this
     this `GradientsDebugger` instance and can later be retrieved, with the
     methods `gradient_tensor` and `gradient_tensors`.

@@ -141,7 +141,7 @@ class GradientsDebugger(object):
     Args:
       input_tensor: the input `tf.Tensor` object whose related gradient tensors
-        are to be reigstered with this `GradientsDebugger` instance when they
+        are to be registered with this `GradientsDebugger` instance when they
         are created, e.g., during `tf.gradients` calls or the construction
         of optimization (training) op that uses `tf.gradients`.

@@ -173,7 +173,7 @@ class GradientsDebugger(object):
     The side effect of this method is that when gradient tensor(s) are created
     with respect to the any paths that include the `x_tensor`s, the gradient
-    tensor(s) with repsect to the tensor will be registered with this
+    tensor(s) with respect to the tensor will be registered with this
     this `GradientsDebugger` instance and can later be retrieved, with the
     methods `gradient_tensor` and `gradient_tensors`.
@@ -144,7 +144,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
     self._compareOriginalAndReconstructedGraphDefs(
         sess, c, expected_output=400.0)

-  def testReonstructGraphWithCond(self):
+  def testReconstructGraphWithCond(self):
     with session.Session(config=self._no_rewrite_session_config()) as sess:
       x = variables.Variable(10.0, name="x")
       y = variables.Variable(20.0, name="y")

@@ -658,7 +658,7 @@ class DebugIdentityV2OpTest(dumping_callback_test_lib.DumpingCallbackTestBase):
             tensor_id=x._id,
             output_dtype=dtypes.float64)), x._id

-    # Assert the same op is returns a consistant value
+    # Assert the same op is returns a consistent value
     x = np.zeros([100, 100], dtype=np.float16)
     x[32, 47] = np.nan
     x[0:4, 3] = np.inf

@@ -67,7 +67,7 @@ def _concrete_tensor_to_proto(tensor):


 class _DumpingCallback(object):
-  """An object holding the states surrouding the dumping callback."""
+  """An object holding the states surrounding the dumping callback."""

   def __init__(self,
                dump_root,
@@ -1352,7 +1352,7 @@ class TracingCallbackTest(
       ("FullTensor", "FULL_TENSOR"),
   )
   @test_util.run_in_graph_and_eager_modes
-  def testMobiletNetV2Fit(self, tensor_debug_mode):
+  def testMobileNetV2Fit(self, tensor_debug_mode):
     """Test training Keras MobileNetV2 works with dumping."""
     # Use a large circular-buffer to make sure we capture all the executed ops.
     writer = dumping_callback.enable_dump_debug_info(

@@ -574,7 +574,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
       if i in (0, 2):
         # During runs 0 and 2, the server should have received the published
         # debug tensor delta:0:DebugIdentity. The breakpoint should have been
-        # unblocked by EventReply reponses from the server.
+        # unblocked by EventReply responses from the server.
         self.assertAllClose(
             [5.0],
             self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])

@@ -628,7 +628,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
       if i in (0, 2):
         # During runs 0 and 2, the server should have received the published
         # debug tensor delta:0:DebugIdentity. The breakpoint should have been
-        # unblocked by EventReply reponses from the server.
+        # unblocked by EventReply responses from the server.
         self.assertAllClose(
             [5.0],
             self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
@@ -116,7 +116,7 @@ def _send_call_tracebacks(destinations,
     origin_stack: The traceback stack for the origin of the execution call. For
       graph execution, this is the traceback of the `tf.Session.run()`
       invocation. For eager execution, this is the traceback of the Python
-      line that executes the eager opertion.
+      line that executes the eager operation.
     is_eager_execution: (`bool`) whether an eager execution call (i.e., not a
       `tf.Session.run` or derived methods) is being sent.
     call_key: The key of the execution call, as a string. For graph execution,

@@ -73,7 +73,7 @@ def guess_is_tensorflow_py_library(py_file_path):
   Raises:
     ValueError: if the extension name of py_file_path does not indicate a Python
-      source file (compiled or uncomplied).
+      source file (compiled or uncompiled).
   """
   if (not is_extension_uncompiled_python_source(py_file_path) and
       not is_extension_compiled_python_source(py_file_path)):

@@ -341,7 +341,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
     #   while/Less:0 4
     #   while/LoopCond:0 4
     #   while/Switch:0 1
-    #   while/Swtich:1 3
+    #   while/Switch:1 3
     #   while/Identity:0 3
     #   while/Add/y:0 3
     #   while/Add:0 3
@@ -90,7 +90,7 @@ class DumpingDebugWrapperSession(framework.NonInteractiveDebugWrapperSession):
     self._run_counter_lock = threading.Lock()

   def prepare_run_debug_urls(self, fetches, feed_dict):
-    """Implementation of abstrat method in superclass.
+    """Implementation of abstract method in superclass.

     See doc of `NonInteractiveDebugWrapperSession.prepare_run_debug_urls()`
     for details. This implementation creates a run-specific subdirectory under

@@ -44,7 +44,7 @@ c) (To be implemented in a future CL) Enter an instruction loop to let an
 3) The callback handles the request and returns a OnSessionInitResponse
    object with an action field, directing the wrapper session what to do next.

-If the action field in the OnSessionInitResponse is PROCEED, the constuctor
+If the action field in the OnSessionInitResponse is PROCEED, the constructor
 returns. Control is released back to the caller of the constructor, which can
 invoke run() method of wrapper session with the same syntax as a non-wrapped
 session, e.g.,:

@@ -69,7 +69,7 @@ A1) Right at the start of each run() call, the on_run_start() callback is

 A2) Right before the run() returns, the on_run_end() callback is invoked,
     with an OnRunEndRequest object as the argument, which carries information
-    including the actual action performed in the warpper run() call and the
+    including the actual action performed in the wrapper run() call and the
     run_metadata from the run() call.

 However, if the action field in OnSessionInitResponse is
@ -393,7 +393,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||||
and caused the preparation of this run-end CLI (if any).
|
and caused the preparation of this run-end CLI (if any).
|
||||||
passed_filter_exclude_node_names: (None or str) Regular expression used
|
passed_filter_exclude_node_names: (None or str) Regular expression used
|
||||||
with the tensor filter to exclude ops with names matching the regular
|
with the tensor filter to exclude ops with names matching the regular
|
||||||
expresssion.
|
expression.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if tf_error:
|
if tf_error:
|
||||||
|
|
|
@ -345,7 +345,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||||
self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
|
self.assertEqual(0, len(wrapped_sess.observers["debug_dumps"]))
|
||||||
self.assertEqual([], wrapped_sess.observers["tf_errors"])
|
self.assertEqual([], wrapped_sess.observers["tf_errors"])
|
||||||
|
|
||||||
def testRunMixingDebugModeAndMultpleTimes(self):
|
def testRunMixingDebugModeAndMultipleTimes(self):
|
||||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||||
[["run", "-n"], ["run", "-t", "2"], ["run"], ["run"]],
|
[["run", "-n"], ["run", "-t", "2"], ["run"], ["run"]],
|
||||||
self.sess, dump_root=self._tmp_dir)
|
self.sess, dump_root=self._tmp_dir)
|
||||||
|
|
|
@@ -519,7 +519,7 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op):


 def _build_recursive_hd_scatter(input_tensors, devices):
-"""Construct the scatter phase of recursive halving-doublng all-reduce.
+"""Construct the scatter phase of recursive halving-doubling all-reduce.

 Args:
 input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
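The scatter phase named above is essentially an allgather performed by recursive doubling. A toy, pure-Python model of the communication pattern follows; it assumes a power-of-two participant count and is illustrative only, not the TF implementation.

```
# Toy model of the recursive-doubling scatter (allgather) phase: after
# log2(n) rounds of pairwise exchanges, every participant holds all shards.
def recursive_doubling_allgather(shards):
  n = len(shards)  # assumed to be a power of two
  held = [{i: shards[i]} for i in range(n)]  # shard index -> shard value
  span = 1
  while span < n:
    new_held = []
    for i in range(n):
      peer = i ^ span  # exchange partner at this round
      merged = dict(held[i])
      merged.update(held[peer])
      new_held.append(merged)
    held = new_held
    span *= 2
  return held

# Example: 4 participants, each starting with one fully-reduced shard.
print(recursive_doubling_allgather(["s0", "s1", "s2", "s3"])[2])
# -> a dict holding all four shards
```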
@@ -190,7 +190,7 @@ class SlurmClusterResolver(ClusterResolver):
 defaults to None.

 Returns:
-A string specifying job name the process belongs to and an integner
+A string specifying job name the process belongs to and an integer
 specifying the task index the process belongs to in that job.
 """
 return self.task_type, self.task_id
@@ -200,7 +200,7 @@ class SlurmClusterResolver(ClusterResolver):

 Args:
 task_type: (Optional) Overrides the default auto-selected task type.
-task_id: (Optional) Overrides the default auto-slected task index.
+task_id: (Optional) Overrides the default auto-selected task index.
 rpc_layer: (Optional) Overrides the default RPC protocol TensorFlow uses
 to communicate across nodes.
@@ -63,7 +63,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):

 When 'TF_CONFIG' environment variable is set, it parses cluster_spec,
 task_type and task_id from 'TF_CONFIG' and turns into a multi-worker strategy
-which mirrores models on GPUs of all machines in a cluster. In the current
+which mirrors models on GPUs of all machines in a cluster. In the current
 implementation, it uses all GPUs in a cluster and it assumes all workers have
 the same number of GPUs.
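For reference, a TF_CONFIG value of the shape this docstring refers to looks roughly like the following; the host addresses are placeholders.

```
# A representative TF_CONFIG for a two-worker cluster. cluster_spec,
# task_type and task_id are parsed from this JSON by the strategy.
import json
import os

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["worker0.example.com:2222", "worker1.example.com:2222"]
    },
    "task": {"type": "worker", "index": 0}
})
```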
@@ -111,7 +111,7 @@ class CollectiveAllReduceStrategy(distribute_lib.Strategy):

 @classmethod
 def _from_local_devices(cls, devices):
-"""A convenience method to create an obejct with a list of devices."""
+"""A convenience method to create an object with a list of devices."""
 obj = cls()
 obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
 return obj
@@ -435,7 +435,7 @@ def _group_value_by_device(per_replica_values):
 ]

 Args:
-per_replica_values: a list of PerReplica obejcts.
+per_replica_values: a list of PerReplica objects.

 Returns:
 a list of lists, each sublist has components for its corresponding device of
@@ -761,7 +761,7 @@ def stitch_values(values_and_indices_list):

 Args:
 values_and_indices_list: a list of tuples of values and indices indicating
-the values and postions in the returned list.
+the values and positions in the returned list.

 Returns:
 a stitched list of values.
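A pure-Python sketch consistent with that docstring follows; it shows the stitching contract but is not necessarily the verbatim TF implementation.

```
# Sketch: place each value at its recorded position in one flat list.
def stitch_values(values_and_indices_list):
  total = sum(len(values) for values, _ in values_and_indices_list)
  result = [None] * total
  for values, indices in values_and_indices_list:
    for value, index in zip(values, indices):
      assert result[index] is None  # each position filled exactly once
      result[index] = value
  return result

# Example: two (values, indices) tuples interleave into one stitched list.
print(stitch_values([(["a", "c"], [0, 2]), (["b", "d"], [1, 3])]))
# -> ['a', 'b', 'c', 'd']
```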
@@ -62,7 +62,7 @@ def canonicalize(d, default=None):
 result = result.make_merged_spec(
 tf_device.DeviceSpec.from_string(default))

-# Apply `d` last, so that it's values take precidence over the defaults.
+# Apply `d` last, so that its values take precedence over the defaults.
 result = result.make_merged_spec(d)
 return result.to_string()
@@ -284,12 +284,12 @@ class _WorkerContext(object):

 @property
 def task_type(self):
-"""Returns the role of the corresponing task."""
+"""Returns the role of the corresponding task."""
 return self._task_type

 @property
 def task_id(self):
-"""Returns the id or index of the corresponing task."""
+"""Returns the id or index of the corresponding task."""
 return self._task_id

 @property
@@ -364,7 +364,7 @@ def _split_cluster_for_evaluator(cluster_spec, task_type):
 """Split the cluster for evaluator since it needn't talk to other tasks."""
 # Splitting the cluster is important to prevent the evaluator from talking to
 # other tasks in the cluster. Since we allow evaluator not to use
-# distribution strategies and as a result ops in the evalauator task may have
+# distribution strategies and as a result ops in the evaluator task may have
 # unspecified devices. Those ops may end up on other tasks if we don't split
 # the cluster.
 # Note: if you bypass distribute coordinator and bring the cluster yourself,
@@ -694,7 +694,7 @@ def run_distribute_coordinator(worker_fn,
 operations.

 This method is intended to be invoked by high-level APIs so that users don't
-have to explictly call it to run this coordinator. For those who don't use
+have to explicitly call it to run this coordinator. For those who don't use
 high-level APIs, to change a program to use this coordinator, wrap everything
 in a the program after global data definitions such as commandline flag
 definition into the `worker_fn` and get task-specific configurations from
@@ -593,7 +593,7 @@ class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
 ("fake_evaluator", 3, True, False))


-class DistributeCoordinatorTestInpendentWorkerMode(
+class DistributeCoordinatorTestIndependentWorkerMode(
 DistributeCoordinatorTestBase):

 def testInGraph(self):
@@ -946,7 +946,7 @@ class RunStandardTensorflowServerTest(test.TestCase):


 if __name__ == "__main__":
-# TODO(yuefengz): find a smart way to terminite std server threads.
+# TODO(yuefengz): find a smart way to terminate std server threads.
 with test.mock.patch.object(sys, "exit", os._exit):
 # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly.
 orig_init = session_manager.SessionManager.__init__
@@ -1998,7 +1998,7 @@ class StrategyExtendedV1(StrategyExtendedV2):
 - last_step_outputs: A dictionary containing tensors set using
 `context.set_last_step_output`. Evaluating this returns the value of
 the tensors after the last iteration.
-- non_tensor_outputs: A dictionatry containing anything that was set by
+- non_tensor_outputs: A dictionary containing anything that was set by
 `fn` by calling `context.set_non_tensor_output`.
 """
 _require_cross_replica_or_default_context_extended(self)
@@ -158,7 +158,7 @@ def init_run_config(config, tf_config):
 return

 # Don't use distribute coordinator if it is local training or cluster has a
-# MASTER job or `train_distribute` is not specifed.
+# MASTER job or `train_distribute` is not specified.
 if (not cluster_spec or 'master' in cluster_spec.jobs or
 not config._train_distribute):
 config._distribute_coordinator_mode = None
@@ -176,7 +176,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
 with ops.device(worker):
 worker_has_value, next_element = (
 iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
-# Collective all-reduce requires explict devices for inputs.
+# Collective all-reduce requires explicit devices for inputs.
 with ops.device("/cpu:0"):
 # Converting to integers for all-reduce.
 worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
@@ -260,7 +260,7 @@ def _group_device_list(devices):

 Returns:
 a dict of list of device strings mapping from task_type to a list of devices
-for the task_type in the asceding order of task_id.
+for the task_type in the ascending order of task_id.
 """
 assert not _is_device_list_single_worker(devices)
 device_dict = {}
@@ -1319,7 +1319,7 @@ class MirroredVariableStopGradientTest(test.TestCase, parameterized.TestCase):

 class FunctionTest(test.TestCase):

-def testBackwardFuctionDevicePlacement(self):
+def testBackwardFunctionDevicePlacement(self):
 if context.num_gpus() < 1:
 self.skipTest("At least one GPU is required.")
 devices = [device_util.resolve("/device:GPU:0"),
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""A base class to provid a model and corresponding input data for testing."""
+"""A base class to provide a model and corresponding input data for testing."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -42,7 +42,7 @@ def _get_data_for_simple_models():


 class SimpleFunctionalModel(model_collection_base.ModelAndInput):
-"""A simple functinal model and its inputs."""
+"""A simple functional model and its inputs."""

 def get_model(self, **kwargs):
 output_name = 'output_layer'
@@ -550,7 +550,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 cluster_spec=None,
 task_type=None,
 task_id=None):
-"""Configures the strategy class with `cluser_spec`.
+"""Configures the strategy class with `cluster_spec`.

 The strategy object will be re-initialized if `cluster_spec` is passed to
 `configure` but was not passed when instantiating the strategy.
@@ -186,7 +186,7 @@ class ParameterServerStrategyTestBase(
 g = e + 1.0
 self.assertEqual(g.device, worker_device + '/device:CPU:1')

-# Ths ops.colocate_with will be ignored when defining a variale but not
+# The ops.colocate_with will be ignored when defining a variable but not
 # for a normal tensor.
 with ops.colocate_with(x):
 u = variable_scope.get_variable('u', initializer=30.0)
@@ -340,7 +340,7 @@ class ParameterServerStrategyTestBase(
 g = e + 1.0
 self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

-# Ths ops.colocate_with will be ignored when defining a variale but not
+# The ops.colocate_with will be ignored when defining a variable but not
 # for a normal tensor.
 with ops.colocate_with(x):
 u = variable_scope.get_variable('u', initializer=30.0)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Utilites for reduce operations."""
+"""Utilities for reduce operations."""

 from __future__ import absolute_import
 from __future__ import division
@@ -142,7 +142,7 @@ class TestSavedModelBase(test.TestCase, parameterized.TestCase):
 def _save_model(self, model, saved_dir):
 """Save the given model to the given saved_dir.

-This method needs to be implemeted by the subclasses.
+This method needs to be implemented by the subclasses.

 Args:
 model: a keras model object to save.
@@ -1409,7 +1409,7 @@ class PerReplicaTest(test.TestCase, parameterized.TestCase):
 vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
 per_replica = values.PerReplica(vals)

-# Note: nest.map_structutre exercises nest.flatten and
+# Note: nest.map_structure exercises nest.flatten and
 # nest.pack_sequence_as.
 result = nest.map_structure(
 lambda t: t + 10, per_replica, expand_composites=True)
@@ -25,7 +25,7 @@ To run a subset of benchmarks using --benchmarks flag.
 --benchmarks: the list of benchmarks to run. The specified value is interpreted
 as a regular expression and any benchmark whose name contains a partial match
 to the regular expression is executed.
-e.g. --benchmarks=".*matmul*." will run all matmul related benmarks.
+e.g. --benchmarks=".*matmul*." will run all matmul related benchmarks.

 """
 from __future__ import absolute_import
@@ -65,7 +65,7 @@ class _CallCounter(object):
 break

 def called_without_tracing(self):
-# We don't count tracing when users load a concrete function dicretly or
+# We don't count tracing when users load a concrete function directly or
 # call get_concrete_function, so the first call can be not a tracing call.
 if not self._calls_per_tracings:
 self._calls_per_tracings = [0]
@@ -380,7 +380,7 @@ class Function(object):
 tensorflow.autograph.Feature values. Allows enabling additional
 conversion options when autograph is set to True.
 experimental_relax_shapes: When true, argument shapes may be relaxed to
-avoid unecessary retracing.
+avoid unnecessary retracing.
 experimental_compile: If false, execute the function in a regular way. The
 function is optimized by some graph rewrite passes (some ops might be
 clustered into a single op) and interpreted by the standard TensorFlow
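A short usage sketch of the experimental_relax_shapes flag documented above, assuming the TF 2.x `tf.function` spelling of the flag; exact relaxation behavior varies by version.

```
import tensorflow as tf

@tf.function(experimental_relax_shapes=True)
def total(x):
  # Summing over the last axis works for any leading dimension.
  return tf.reduce_sum(x, axis=-1)

total(tf.ones([2, 3]))
# A different leading dimension need not force a full retrace: the traced
# signature may be relaxed (e.g. toward [None, 3]) to avoid unnecessary
# retracing.
total(tf.ones([5, 3]))
```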
@@ -728,7 +728,7 @@ class Function(object):
 @function_lib.defun(autograph=False)
 def initialize_variables():
 op_map = object_identity.ObjectIdentityDictionary()
-# Stack all the var_is_initialized values into one tensor and intepret the
+# Stack all the var_is_initialized values into one tensor and interpret the
 # numpy value. This will reduce the number of RPCs between client and
 # worker in the remote case.
 with ops.init_scope():
@@ -936,7 +936,7 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):

 # NOTE: assert_no_new_pyobjects_executing_eagerly fails flakily on this
 # test... could be something wrong with the test decorator, or some sort of
-# nondeterminstic caching.
+# nondeterministic caching.
 def testMirroredVariableWatched(self):

 def _replicated(input_tangent):
@@ -65,7 +65,7 @@ def push_forwardprop_state():
 temporarily reset its state. This is useful when building forwardprop versions
 of functions, where an accumulator will trigger function building and then
 must process captured symbolic tensors while building it. Without pushing and
-poping, accumulators ignore operations executed as a direct result of their
+popping, accumulators ignore operations executed as a direct result of their
 own jvp computations.

 Yields:
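A usage sketch of the context manager this hunk documents. This is an internal API; the module path below is an assumption and varies across TF versions, and the helper is a trivial stand-in.

```
# Hypothetical usage; the forwardprop_util location is an assumption.
import tensorflow as tf
from tensorflow.python.eager import forwardprop_util

def build_forward_version():
  # Stand-in for tracing a forwardprop version of a function.
  return tf.constant(1.0) * 2.0

with forwardprop_util.push_forwardprop_state():
  # While the pushed state is active, existing accumulators ignore the ops
  # executed here, so jvp-building work does not record its own jvps.
  build_forward_version()
```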
@@ -358,7 +358,7 @@ def add_function_callback(function_callback):

 wherein `function` is the just-created _EagerDefinedFunction.
 The callback is invoked immediately after a new `_EagerDefinedFunction`
-is created. The return value(s) of the callback fucntion (if any) is ignored.
+is created. The return value(s) of the callback function (if any) is ignored.

 Repeated registration of the same callback function is idempotent.
 After a callback is added, it can be removed with the
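A sketch of registering such a callback. This is an internal, unstable API; the import path and the removal counterpart are assumptions based on the docstring above.

```
# Hypothetical registration of a function callback against the internal
# tensorflow.python.eager.function module (unstable API).
from tensorflow.python.eager import function as function_lib

def log_new_function(fn):
  # Invoked right after each new _EagerDefinedFunction is created; any
  # return value is ignored.
  print("defined:", fn.name)

function_lib.add_function_callback(log_new_function)
# ... trace some functions ...
function_lib.remove_function_callback(log_new_function)  # assumed counterpart
```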
@@ -850,7 +850,7 @@ class _DelayedRewriteGradientFunctions(object):
 higher-order symbolic gradients (tf.gradients).

 Args:
-flat_outputs: The restult of running `forward`.
+flat_outputs: The result of running `forward`.
 inference_args: A flat list of Tensors with inference inputs to the
 operation.
 input_tangents: A flat list of Tensors with input tangents consumed by the
@@ -1314,7 +1314,7 @@ class _TapeGradientFunctions(object):
 have produced tangents which need to be recorded.

 Args:
-flat_outputs: The restult of running `forward`.
+flat_outputs: The result of running `forward`.
 inference_args: A flat list of Tensors with inference inputs to the
 operation.
 input_tangents: A flat list of Tensors with input tangents consumed by the
@@ -1757,7 +1757,7 @@ class ConcreteFunction(object):
 return self._build_call_outputs(flat_outputs)

 def _experimental_with_cancellation_manager(self, cancellation_manager):
-"""Returns a callable that invokes a cancelable version of this function.
+"""Returns a callable that invokes a cancellable version of this function.

 Args:
 cancellation_manager: A `CancellationManager` object that can be used to
@@ -2376,7 +2376,7 @@ class Function(object):
 `when autograph=True`. See https://www.tensorflow.org/guide/autograph
 for more information.
 experimental_relax_shapes: When true, argument shapes may be relaxed to
-avoid unecessary retracing.
+avoid unnecessary retracing.
 capture_by_value: Experimental. Whether to capture resource variables by
 value or reference. If None, will inherit from a parent context or
 default to False.
@@ -2668,7 +2668,7 @@ class Function(object):
 return graph_function

 def _define_function_with_shape_relaxation(self, args, kwargs):
-"""Define a function, relaxing arg shapes to avoid unecessary retracing."""
+"""Define a function, relaxing arg shapes to avoid unnecessary retracing."""

 rank_only_cache_key = self._cache_key(
 args, kwargs, include_tensor_ranks_only=True)
@@ -2824,7 +2824,7 @@ def defun(func=None,
 the values of its non-Tensor Python objects.

 When eager execution is enabled, the ability to create graphs from Python
-functions makes it possible to incrementally trade off debugability and
+functions makes it possible to incrementally trade off debuggability and
 interactivity for performance. Functions compiled with `defun` cannot be
 inspected with `pdb`; however, executing a graph
 generated by `defun` sometimes takes less time and memory than eagerly
@@ -3130,7 +3130,7 @@ def defun(func=None,
 of tensorflow.autograph.Feature values) to control behavior when
 autograph=True.
 experimental_relax_shapes: When true, argument shapes may be relaxed to
-avoid unecessary retracing.
+avoid unnecessary retracing.

 Returns:
 If `func` is not None, returns a callable that will execute the compiled
@@ -3508,7 +3508,7 @@ class MultiDeviceTest(test.TestCase, parameterized.TestCase):
 self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)

 @test_util.run_gpu_only
-def testArgumentPrunning(self):
+def testArgumentPruning(self):
 """Tests functions taking unnecessary arguments."""
 with ops.device('/device:CPU:0'):
 c1 = constant_op.constant(5.0)
@@ -359,7 +359,7 @@ typedef struct EagerTensor {
 TFE_TensorHandle* handle;
 int64_t id;
 // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
-// be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
+// be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
 // tensors, this will contain a serialized HandleData proto with shape
 // inference metadata about shapes and dtypes of resources accessible from
 // this handle.
@@ -660,7 +660,7 @@ static PyObject* EagerTensor_backing_device(EagerTensor* self) {
 #endif
 }

-static PyGetSetDef EagerTensor_getseters[] = {
+static PyGetSetDef EagerTensor_getsetters[] = {
 {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
 const_cast<char*>("Tensor ID."), nullptr},
 {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
@@ -758,7 +758,7 @@ PyTypeObject* EagerTensorType = nullptr;
 static PyType_Slot EagerTensor_Type_slots[] = {
 {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
 {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
-{Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getseters)},
+{Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
 {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
 {0, nullptr},
 };
@@ -799,7 +799,7 @@ static PyTypeObject _EagerTensorType = {
 nullptr, /* tp_iternext */
 EagerTensor_methods, /* tp_methods */
 EagerTensor_members, /* tp_members */
-EagerTensor_getseters, /* tp_getset */
+EagerTensor_getsetters, /* tp_getset */
 nullptr, /* tp_base */
 nullptr, /* tp_dict */
 nullptr, /* tp_descr_get */
@@ -197,7 +197,7 @@ PyObject* TFE_Py_TapeSetIsStopped();
 // forwardprop to, given the gradients of the output tensors, produce the
 // gradients of the input tensors. This function is automatically transposed
 // during forwardprop.
-// - forward_function is an optional special-case for fowardprop, taking input
+// - forward_function is an optional special-case for forwardprop, taking input
 // jvps and returning output jvps.
 //
 // Records an operation both for backprop (gradient tape) and forwardprop
@@ -307,7 +307,7 @@ PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, PyObject* tensor);
 // temporarily reset its state. This is useful when building forwardprop
 // versions of functions, where an accumulator will trigger function building
 // and then must process captured symbolic tensors while building it. Without
-// pushing and poping, accumulators ignore operations executed as a direct
+// pushing and popping, accumulators ignore operations executed as a direct
 // result of their own jvp computations.
 PyObject* TFE_Py_ForwardAccumulatorPushState();
 PyObject* TFE_Py_ForwardAccumulatorPopState();
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Gradient tape utilites."""
+"""Gradient tape utilities."""

 from __future__ import absolute_import
 from __future__ import division
@@ -270,7 +270,7 @@ class WrappedFunction(function.ConcreteFunction):
 tensor_fetches = []
 tensor_infos = []

-def _fetch_preprocesing_callback(fetch):
+def _fetch_preprocessing_callback(fetch):
 """Extract out lists of ops, tensors, and tensor type info.

 Turns TensorInfos into Tensors in the original `fetches` structure.
@@ -300,9 +300,9 @@ class WrappedFunction(function.ConcreteFunction):
 return fetch
 else:
 graph_element = self.graph.as_graph_element(fetch)
-return _fetch_preprocesing_callback(graph_element)
+return _fetch_preprocessing_callback(graph_element)

-fetches = nest.map_structure(_fetch_preprocesing_callback, fetches)
+fetches = nest.map_structure(_fetch_preprocessing_callback, fetches)

 # Expand composite tensors into their component dense Tensors.
 tensor_fetches = nest.flatten(tensor_fetches, expand_composites=True)
@@ -63,7 +63,7 @@ def _initialized_session(config=None):

 class LazyColumnTest(test.TestCase):

-def test_transormations_called_once(self):
+def test_transformations_called_once(self):

 class TransformCounter(_FeatureColumn):
@@ -1131,7 +1131,7 @@ class CrossedColumnTest(test.TestCase):
 def test_linear_model(self):
 """Tests linear_model.

-Uses data from test_get_sparse_tesnsors_simple.
+Uses data from test_get_sparse_tensors_simple.
 """
 a = fc._numeric_column('a', dtype=dtypes.int32, shape=(2,))
 b = fc._bucketized_column(a, boundaries=(0, 1))
@@ -1213,7 +1213,7 @@ class CrossedColumnTest(test.TestCase):
 def test_keras_linear_model(self):
 """Tests _LinearModel.

-Uses data from test_get_sparse_tesnsors_simple.
+Uses data from test_get_sparse_tensors_simple.
 """
 a = fc._numeric_column('a', dtype=dtypes.int32, shape=(2,))
 b = fc._bucketized_column(a, boundaries=(0, 1))
@@ -683,7 +683,7 @@ class LinearModel(training.Model):
 }
 ```

-with `sparse_combiner` as "mean", the linear model outputs conceptly are
+with `sparse_combiner` as "mean", the linear model outputs conceptually are
 ```
 y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
 y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
@@ -1369,7 +1369,7 @@ class CrossedColumnTest(test.TestCase):
 def test_linear_model(self):
 """Tests linear_model.

-Uses data from test_get_sparse_tesnsors_simple.
+Uses data from test_get_sparse_tensors_simple.
 """
 a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
 b = fc.bucketized_column(a, boundaries=(0, 1))
@@ -1463,7 +1463,7 @@ class CrossedColumnTest(test.TestCase):
 def test_old_linear_model(self):
 """Tests linear_model.

-Uses data from test_get_sparse_tesnsors_simple.
+Uses data from test_get_sparse_tensors_simple.
 """
 a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
 b = fc.bucketized_column(a, boundaries=(0, 1))
@@ -1573,7 +1573,7 @@ class CrossedColumnTest(test.TestCase):
 def test_old_linear_model_old_numeric(self):
 """Tests linear_model.

-Uses data from test_get_sparse_tesnsors_simple.
+Uses data from test_get_sparse_tensors_simple.
 """
 a = fc_old._numeric_column('a', dtype=dtypes.int32, shape=(2,))
 b = fc.bucketized_column(a, boundaries=(0, 1))
@@ -197,7 +197,7 @@ def _column_name_with_class_name(fc):
 Without this two FeatureColumns that have the same name and where
 one wraps the other, such as an IndicatorColumn wrapping a
 SequenceCategoricalColumn, will fail to deserialize because they will have the
-same name in colums_by_name, causing the wrong column to be returned.
+same name in columns_by_name, causing the wrong column to be returned.

 Args:
 fc: A FeatureColumn.
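A sketch of the disambiguation that docstring describes; the exact formatting is an assumption, not the verbatim implementation.

```
# Suffix the column name with its class name so that a wrapper column and
# the column it wraps map to different keys in columns_by_name.
def _column_name_with_class_name(fc):
  return '{}:{}'.format(fc.name, type(fc).__name__)

# e.g. an IndicatorColumn wrapping a SequenceCategoricalColumn of the same
# name yields 'col:IndicatorColumn' vs 'col:SequenceCategoricalColumn'.
```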
@@ -614,7 +614,7 @@ class DeviceTest(test.TestCase):
 self.assertIsNotNone(gpu.name)

 @reset_eager
-def testV1CompatibilityDummyInivisibleDeviceList(self):
+def testV1CompatibilityDummyInvisibleDeviceList(self):
 gpus = config.list_physical_devices('GPU')
 if gpus:
 self.skipTest('Test requires no GPUs')
@@ -249,7 +249,7 @@ def _get_control_flow_function_data(node_defs, tensor_data, name_to_node):

 def get_source_node_name_through_identities(node_name):
 # Trace the source node along with a chain of Identity nodes.
-# For example, given Plaecholder -> Identity -> Identity -> node_name
+# For example, given Placeholder -> Identity -> Identity -> node_name
 # The function will return the name of the Placeholder.
 while name_to_node[node_name].op == "Identity":
 node_name = _get_tensor_name(name_to_node[node_name].input[0])
@@ -212,7 +212,7 @@ class DeviceSpecV2(object):
 def make_merged_spec(self, dev):
 """Returns a new DeviceSpec which incorporates `dev`.

-When combining specs, `dev` will take precidence over the current spec.
+When combining specs, `dev` will take precedence over the current spec.
 So for instance:
 ```
 first_spec = tf.DeviceSpec(job=0, device_type="CPU")
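A runnable variant of the docstring's example, using a string-valued job since DeviceSpec fields like `job` are strings; the printed result is what the precedence rule implies.

```
import tensorflow as tf

first_spec = tf.DeviceSpec(job="worker", device_type="CPU")
second_spec = tf.DeviceSpec(device_type="GPU", device_index=0)

# Fields set on `second_spec` take precedence in the merge; fields it leaves
# unset (like job) are kept from `first_spec`.
merged = first_spec.make_merged_spec(second_spec)
print(merged.to_string())  # expected: '/job:worker/device:GPU:0'
```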
@@ -253,7 +253,7 @@ class DeviceSpecV2(object):
 job=self.job, replica=self.replica, task=self.task,
 device_type=self.device_type, device_index=self.device_index)

-# Explicitly provided kwargs take precidence.
+# Explicitly provided kwargs take precedence.
 init_kwargs.update(kwargs)
 return self.__class__(**init_kwargs)
@@ -77,7 +77,7 @@ def convert_structure_to_signature(structure, arg_names=None):

 Returns:
 Identical structure that has TensorSpec objects instead of Tensors and
-UknownArgument instead of any unsupported types.
+UnknownArgument instead of any unsupported types.
 """
 def encode_arg(arg, path):
 """A representation for this argument, for converting into signatures."""
@@ -1197,7 +1197,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
 "either zero or all names have to be specified.")

 for arg in flattened:
-# We have a shape entry for each arg, regadless of whether it's a real
+# We have a shape entry for each arg, regardless of whether it's a real
 # Tensor or not. For non-tensor entries it should be None.
 shape = next(shapes_iter)
 if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
@@ -290,7 +290,7 @@ class _DefinedFunction(object):
 device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
 # pylint: enable=protected-access

-# Get the innermost device if possbile.
+# Get the innermost device if possible.
 self._caller_device = device_funcs[-1] if device_funcs else None

 # Cached OpDef for this function. When C API is enabled, this is
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # =============================================================================
-"""Utlity to convert FunctionDef to GraphDef and Graph."""
+"""Utility to convert FunctionDef to GraphDef and Graph."""

 from __future__ import absolute_import
 from __future__ import division
@@ -124,7 +124,7 @@ def _node_name(n):


 def _get_colocated_node_name(colocated_node_name):
-"""Decodes colocated node name and returns it without loc:@ preprended."""
+"""Decodes colocated node name and returns it without loc:@ prepended."""
 colocated_node_decoded = colocated_node_name.decode("utf-8")
 if colocated_node_decoded.startswith("loc:@"):
 return colocated_node_decoded[5:]
@@ -1045,7 +1045,7 @@ def export_scoped_meta_graph(filename=None,
 name, _ = os.path.splitext(filename)
 debug_filename = "{name}{ext}".format(name=name, ext=".debug")

-# Gets the operation from the graph by the name. Exludes variable nodes,
+# Gets the operation from the graph by the name. Excludes variable nodes,
 # so only the nodes in the frozen models are included.
 # TODO(liufengdb): fix this for functions.
 ops_to_export = []