Enable support for loop directives, include shape_invariants.
PiperOrigin-RevId: 285966483 Change-Id: I3eae0b134cd2e954bfa0ac31e6a7411b3a5bb7df
This commit is contained in:
parent
1768c8f2fa
commit
e42b937e1a
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import gast
|
import gast
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
|
from tensorflow.python.autograph.lang import directives
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import ast_util
|
from tensorflow.python.autograph.pyct import ast_util
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
@ -151,6 +152,20 @@ class ControlFlowTransformer(converter.Base):
|
|||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
def _create_loop_options(self, node):
|
||||||
|
if not anno.hasanno(node, anno.Basic.DIRECTIVES):
|
||||||
|
return gast.Dict([], [])
|
||||||
|
|
||||||
|
loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
|
||||||
|
if directives.set_loop_options not in loop_directives:
|
||||||
|
return gast.Dict([], [])
|
||||||
|
|
||||||
|
opts_dict = loop_directives[directives.set_loop_options]
|
||||||
|
str_keys, values = zip(*opts_dict.items())
|
||||||
|
keys = [gast.Str(s) for s in str_keys]
|
||||||
|
values = list(values) # ast and gast don't play well with tuples.
|
||||||
|
return gast.Dict(keys, values)
|
||||||
|
|
||||||
def _create_undefined_assigns(self, undefined_symbols):
|
def _create_undefined_assigns(self, undefined_symbols):
|
||||||
assignments = []
|
assignments = []
|
||||||
for s in undefined_symbols:
|
for s in undefined_symbols:
|
||||||
@ -383,8 +398,7 @@ class ControlFlowTransformer(converter.Base):
|
|||||||
composite_symbol_names = tuple(
|
composite_symbol_names = tuple(
|
||||||
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
||||||
|
|
||||||
# TODO(b/140125096): Populate.
|
opts = self._create_loop_options(node)
|
||||||
opts = gast.Dict([], [])
|
|
||||||
|
|
||||||
# TODO(mdan): Use a single template.
|
# TODO(mdan): Use a single template.
|
||||||
# If the body and test functions took a single tuple for loop_vars, instead
|
# If the body and test functions took a single tuple for loop_vars, instead
|
||||||
@ -507,8 +521,7 @@ class ControlFlowTransformer(converter.Base):
|
|||||||
composite_symbol_names = tuple(
|
composite_symbol_names = tuple(
|
||||||
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
||||||
|
|
||||||
# TODO(b/140125096): Populate.
|
opts = self._create_loop_options(node)
|
||||||
opts = gast.Dict([], [])
|
|
||||||
|
|
||||||
# TODO(mdan): Use a single template.
|
# TODO(mdan): Use a single template.
|
||||||
# If the body and test functions took a single tuple for loop_vars, instead
|
# If the body and test functions took a single tuple for loop_vars, instead
|
||||||
|
@ -26,7 +26,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
from tensorflow.tools.docs.doc_controls import do_not_generate_docs
|
|
||||||
|
|
||||||
UNSPECIFIED = object()
|
UNSPECIFIED = object()
|
||||||
|
|
||||||
@ -47,37 +46,53 @@ def set_element_type(entity, dtype, shape=UNSPECIFIED):
|
|||||||
del shape
|
del shape
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/140125096): Implement.
|
|
||||||
@do_not_generate_docs
|
|
||||||
@tf_export('autograph.experimental.set_loop_options')
|
@tf_export('autograph.experimental.set_loop_options')
|
||||||
def set_loop_options(
|
def set_loop_options(
|
||||||
parallel_iterations=UNSPECIFIED,
|
parallel_iterations=UNSPECIFIED,
|
||||||
back_prop=UNSPECIFIED,
|
|
||||||
swap_memory=UNSPECIFIED,
|
swap_memory=UNSPECIFIED,
|
||||||
maximum_iterations=UNSPECIFIED):
|
maximum_iterations=UNSPECIFIED,
|
||||||
|
shape_invariants=UNSPECIFIED):
|
||||||
"""Specifies additional arguments to be passed to the enclosing while_loop.
|
"""Specifies additional arguments to be passed to the enclosing while_loop.
|
||||||
|
|
||||||
The parameters apply to and only to the immediately enclosing loop. It only
|
The parameters apply to and only to the immediately enclosing loop. It only
|
||||||
has effect if the loop is staged as a TF while_loop; otherwise the parameters
|
has effect if the loop is staged as a TF while_loop; otherwise the parameters
|
||||||
have no effect.
|
have no effect.
|
||||||
|
|
||||||
Usage example:
|
Usage:
|
||||||
|
|
||||||
@tf.function(autograph=True)
|
>>> @tf.function(autograph=True)
|
||||||
def dynamic_rnn(..., parallel_iterations=32):
|
... def f():
|
||||||
num_steps = ...
|
... n = 0
|
||||||
for t in tf.range(num_steps):
|
... for i in tf.range(10):
|
||||||
tf.autograph.experimental.set_loop_options(
|
... tf.autograph.experimental.set_loop_options(maximum_iterations=3)
|
||||||
parallel_iterations=parallel_iterations)
|
... n += 1
|
||||||
...
|
... return n
|
||||||
|
|
||||||
|
>>> @tf.function(autograph=True)
|
||||||
|
... def f():
|
||||||
|
... v = tf.constant((0,))
|
||||||
|
... for i in tf.range(3):
|
||||||
|
... tf.autograph.experimental.set_loop_options(
|
||||||
|
... shape_invariants=[(v, tf.TensorShape([None]))]
|
||||||
|
... )
|
||||||
|
... v = tf.concat((v, [i]), 0)
|
||||||
|
... return v
|
||||||
|
|
||||||
|
Also see tf.while_loop.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
parallel_iterations: See tf.while_loop.
|
parallel_iterations: The maximum number of iterations allowed to run in
|
||||||
back_prop: See tf.while_loop.
|
parallel at any given time. Note that this does not guarantee parallel
|
||||||
swap_memory: See tf.while_loop.
|
execution.
|
||||||
maximum_iterations: See tf.while_loop.
|
swap_memory: Whether to store intermediate values needed for
|
||||||
|
gradients on the CPU instead of GPU.
|
||||||
|
maximum_iterations: Allows limiting the total number of iterations executed
|
||||||
|
by the loop.
|
||||||
|
shape_invariants: Allows controlling the argument with the same name passed
|
||||||
|
to tf.while_loop. Unlike tf.while_loop, this is a list of
|
||||||
|
`(tensor, shape)` pairs.
|
||||||
"""
|
"""
|
||||||
del parallel_iterations
|
del parallel_iterations
|
||||||
del back_prop
|
|
||||||
del swap_memory
|
del swap_memory
|
||||||
del maximum_iterations
|
del maximum_iterations
|
||||||
|
del shape_invariants
|
||||||
|
@ -125,68 +125,91 @@ def _is_subshape(left, right):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var):
|
# TODO(mdan): Remove these verifications once TF ops can properly report names.
|
||||||
"""Verifies whether init_loop_var and first_iter_var are consistent."""
|
def _verify_single_loop_var(
|
||||||
if isinstance(init_loop_var, (bool, int, float, str)):
|
name, check_shape, init, entry, exit_, shape_invariant):
|
||||||
init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
|
"""Verifies whether the initial, entry and exit values are consistent."""
|
||||||
|
if isinstance(init, (bool, int, float, str, np.ndarray)):
|
||||||
|
init = ops.convert_to_tensor_v2(init)
|
||||||
|
if isinstance(entry, (bool, int, float, str, np.ndarray)):
|
||||||
|
entry = ops.convert_to_tensor_v2(entry)
|
||||||
|
if isinstance(exit_, (bool, int, float, str)):
|
||||||
|
exit_ = ops.convert_to_tensor_v2(exit_)
|
||||||
|
|
||||||
if isinstance(first_iter_var, (bool, int, float, str)):
|
if (not tensor_util.is_tensor(entry) or
|
||||||
first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
|
not tensor_util.is_tensor(exit_)):
|
||||||
|
|
||||||
if (not tensor_util.is_tensor(init_loop_var) or
|
|
||||||
not tensor_util.is_tensor(first_iter_var)):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# TODO(mdan): Properly account for CompositeTensors.
|
# TODO(mdan): Properly account for CompositeTensors.
|
||||||
if (not hasattr(init_loop_var, 'dtype') or
|
if (not hasattr(entry, 'dtype') or
|
||||||
not hasattr(first_iter_var, 'dtype')):
|
not hasattr(exit_, 'dtype')):
|
||||||
return
|
return
|
||||||
if (not hasattr(init_loop_var, 'shape') or
|
if (not hasattr(entry, 'shape') or
|
||||||
not hasattr(first_iter_var, 'shape')):
|
not hasattr(exit_, 'shape')):
|
||||||
return
|
return
|
||||||
|
|
||||||
if init_loop_var.dtype != first_iter_var.dtype:
|
if entry.dtype != exit_.dtype:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'"{}" has dtype {} before the loop, but dtype {} after one'
|
'"{}" has dtype {} before the loop, but dtype {} after one'
|
||||||
' iteration. TensorFlow control flow requires it stays the'
|
' iteration. TensorFlow control flow requires it stays the'
|
||||||
' same.'.format(
|
' same.'.format(
|
||||||
name,
|
name,
|
||||||
init_loop_var.dtype.name,
|
entry.dtype.name,
|
||||||
first_iter_var.dtype.name,
|
exit_.dtype.name,
|
||||||
))
|
))
|
||||||
|
|
||||||
if check_shape:
|
if check_shape:
|
||||||
init_shape = init_loop_var.shape
|
exit_shape = exit_.shape
|
||||||
first_iter_shape = first_iter_var.shape
|
if shape_invariant is None:
|
||||||
# TODO(b/135183013): Update needed once we support shape_invariants.
|
entry_shape = entry.shape
|
||||||
if not _is_subshape(first_iter_shape, init_shape):
|
if not _is_subshape(exit_shape, entry_shape):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'"{}" has shape {} before the loop, but shape {} after one'
|
'"{}" has shape {} before the loop, but shape {} after one'
|
||||||
' iteration. TensorFlow control flow requires it stays the'
|
' iteration. Use tf.autograph.experimental.set_loop_options to set'
|
||||||
' same or be more specific.'.format(name, init_shape,
|
' shape invariants.'.format(name, entry_shape, exit_shape))
|
||||||
first_iter_shape))
|
else:
|
||||||
|
init_shape = init.shape
|
||||||
|
if not _is_subshape(init_shape, shape_invariant):
|
||||||
|
raise ValueError(
|
||||||
|
'"{}" has shape {} before the loop, which does not conform with'
|
||||||
|
' the shape invariant {}.'.format(name, init_shape,
|
||||||
|
shape_invariant))
|
||||||
|
if not _is_subshape(exit_shape, shape_invariant):
|
||||||
|
raise ValueError(
|
||||||
|
'"{}" has shape {} after the loop, which does not conform with'
|
||||||
|
' the shape invariant {}.'.format(
|
||||||
|
name, exit_shape, shape_invariant))
|
||||||
|
|
||||||
|
|
||||||
def _verify_tf_loop_vars(init_loop_vars,
|
def _verify_tf_loop_vars(init_vars,
|
||||||
first_iter_vars,
|
iter_entry_vars,
|
||||||
|
iter_exit_vars,
|
||||||
symbol_names,
|
symbol_names,
|
||||||
opts,
|
opts,
|
||||||
check_shapes=True):
|
check_shapes=True):
|
||||||
"""Verifies loop variables for consistency."""
|
"""Verifies loop variables for consistency."""
|
||||||
# TODO(b/140125096): Use this.
|
if check_shapes and 'shape_invariants' in opts:
|
||||||
del opts
|
shape_invariants = opts['shape_invariants']
|
||||||
|
else:
|
||||||
|
shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)
|
||||||
|
|
||||||
named_vars = zip(symbol_names, init_loop_vars, first_iter_vars)
|
named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
|
||||||
for name, init_loop_var, first_iter_var in named_vars:
|
shape_invariants)
|
||||||
|
for name, init, entry, exit_, invariant in named_vars:
|
||||||
try:
|
try:
|
||||||
nest.assert_same_structure(
|
nest.assert_same_structure(entry, exit_, expand_composites=True)
|
||||||
init_loop_var, first_iter_var, expand_composites=True)
|
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
raise TypeError('"{}" does not have the same nested structure after one'
|
raise TypeError('"{}" does not have the same nested structure after one'
|
||||||
' iteration.\n\n{}'.format(name, e))
|
' iteration.\n\n{}'.format(name, e))
|
||||||
|
if invariant is not None:
|
||||||
|
try:
|
||||||
|
nest.assert_same_structure(init, invariant, expand_composites=False)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise TypeError('"{}" does not have the same nested structure as its'
|
||||||
|
' corresponding shape invariant.\n\n{}'.format(name, e))
|
||||||
|
|
||||||
nest.map_structure(
|
nest.map_structure(
|
||||||
functools.partial(_verify_single_loop_var, name, check_shapes),
|
functools.partial(_verify_single_loop_var, name, check_shapes), init,
|
||||||
init_loop_var, first_iter_var)
|
entry, exit_, invariant)
|
||||||
|
|
||||||
|
|
||||||
def _verify_single_cond_var(name, body_var, orelse_var):
|
def _verify_single_cond_var(name, body_var, orelse_var):
|
||||||
@ -425,6 +448,8 @@ def _tf_ragged_for_stmt(iter_,
|
|||||||
else:
|
else:
|
||||||
n = iter_.row_lengths()[0]
|
n = iter_.row_lengths()[0]
|
||||||
|
|
||||||
|
opts['maximum_iterations'] = n
|
||||||
|
|
||||||
def while_body(iterate_index, *loop_vars):
|
def while_body(iterate_index, *loop_vars):
|
||||||
"""Main loop body."""
|
"""Main loop body."""
|
||||||
iterate = iter_[iterate_index]
|
iterate = iter_[iterate_index]
|
||||||
@ -566,7 +591,7 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
|
|||||||
# Note: this verification duplicates that perfrmed in tf_while_stmt,
|
# Note: this verification duplicates that perfrmed in tf_while_stmt,
|
||||||
# but needs to be done earlier to prevent the tf.cond inside while_body
|
# but needs to be done earlier to prevent the tf.cond inside while_body
|
||||||
# from blowing up first.
|
# from blowing up first.
|
||||||
_verify_tf_loop_vars(loop_vars, new_vars,
|
_verify_tf_loop_vars(init_vars, loop_vars, new_vars,
|
||||||
basic_symbol_names + composite_symbol_names, opts)
|
basic_symbol_names + composite_symbol_names, opts)
|
||||||
return new_vars
|
return new_vars
|
||||||
|
|
||||||
@ -653,20 +678,26 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
|
|||||||
|
|
||||||
# TODO(mdan): Simplify this - following it is extremely difficult.
|
# TODO(mdan): Simplify this - following it is extremely difficult.
|
||||||
|
|
||||||
|
init_state = get_state()
|
||||||
|
aug_init_vars = init_vars, init_state
|
||||||
|
|
||||||
def scan_body(aug_vars, iterate):
|
def scan_body(aug_vars, iterate):
|
||||||
"""The main loop body wrapper. Only calculates the stop condition."""
|
"""The main loop body wrapper. Only calculates the stop condition."""
|
||||||
loop_vars, state = aug_vars
|
loop_vars, state = aug_vars
|
||||||
|
|
||||||
def true_fn():
|
def true_fn():
|
||||||
|
"""Main path - stop condition is not set."""
|
||||||
set_state(state)
|
set_state(state)
|
||||||
outputs = body(iterate, *loop_vars)
|
new_vars = body(iterate, *loop_vars)
|
||||||
|
new_state = get_state()
|
||||||
_verify_tf_loop_vars(
|
_verify_tf_loop_vars(
|
||||||
|
init_vars + init_state,
|
||||||
loop_vars + state,
|
loop_vars + state,
|
||||||
outputs + state,
|
new_vars + new_state,
|
||||||
basic_symbol_names + composite_symbol_names,
|
basic_symbol_names + composite_symbol_names,
|
||||||
opts,
|
opts,
|
||||||
check_shapes=False)
|
check_shapes=False)
|
||||||
return outputs, get_state()
|
return new_vars, new_state
|
||||||
|
|
||||||
extra_cond = extra_test(*loop_vars)
|
extra_cond = extra_test(*loop_vars)
|
||||||
new_vars, new_state = control_flow_ops.cond(
|
new_vars, new_state = control_flow_ops.cond(
|
||||||
@ -690,11 +721,9 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
|
|||||||
del extra_cond
|
del extra_cond
|
||||||
return output_aug_vars, output_state
|
return output_aug_vars, output_state
|
||||||
|
|
||||||
init_state = get_state()
|
ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
|
||||||
aug_vars = init_vars, init_state
|
|
||||||
ds = _general_purpose_scan(ds, aug_vars, scan_body)
|
|
||||||
ds = ds.apply(take_while_ops.take_while(take_while_predicate))
|
ds = ds.apply(take_while_ops.take_while(take_while_predicate))
|
||||||
final_aug_vars = ds.reduce(aug_vars, reduce_body)
|
final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
|
||||||
final_vars, final_state = final_aug_vars
|
final_vars, final_state = final_aug_vars
|
||||||
set_state(final_state)
|
set_state(final_state)
|
||||||
return final_vars
|
return final_vars
|
||||||
@ -741,6 +770,7 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
|
|||||||
new_state = get_state()
|
new_state = get_state()
|
||||||
|
|
||||||
_verify_tf_loop_vars(
|
_verify_tf_loop_vars(
|
||||||
|
init_vars + init_state,
|
||||||
loop_vars + state,
|
loop_vars + state,
|
||||||
new_vars + new_state,
|
new_vars + new_state,
|
||||||
symbol_names,
|
symbol_names,
|
||||||
@ -824,11 +854,23 @@ def while_stmt(test,
|
|||||||
return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
|
return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
|
||||||
|
|
||||||
|
|
||||||
|
def _shape_invariants_mapping_to_positional_list(mapping, keys):
|
||||||
|
# The keys are not expected to be hashable.
|
||||||
|
mapping = {id(k): (k, v) for k, v in mapping}
|
||||||
|
result = []
|
||||||
|
for k in keys:
|
||||||
|
map_key, map_val = mapping.get(id(k), (None, None))
|
||||||
|
result.append(map_val if map_key is k else None)
|
||||||
|
return tuple(result)
|
||||||
|
|
||||||
|
|
||||||
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
||||||
basic_symbol_names, composite_symbol_names, opts):
|
basic_symbol_names, composite_symbol_names, opts):
|
||||||
"""Overload of while_stmt that stages a TF while_stmt."""
|
"""Overload of while_stmt that stages a TF while_stmt."""
|
||||||
_disallow_undefs_into_loop(*init_vars)
|
_disallow_undefs_into_loop(*init_vars)
|
||||||
|
|
||||||
|
aug_init_vars = init_vars + get_state()
|
||||||
|
|
||||||
# TODO(mdan): Simplify this.
|
# TODO(mdan): Simplify this.
|
||||||
loop_vars_slice = slice(len(init_vars))
|
loop_vars_slice = slice(len(init_vars))
|
||||||
state_slice = slice(len(init_vars), None)
|
state_slice = slice(len(init_vars), None)
|
||||||
@ -844,7 +886,7 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
|||||||
set_state(state)
|
set_state(state)
|
||||||
loop_vars = body(*aug_loop_vars[loop_vars_slice])
|
loop_vars = body(*aug_loop_vars[loop_vars_slice])
|
||||||
new_state = loop_vars + get_state()
|
new_state = loop_vars + get_state()
|
||||||
_verify_tf_loop_vars(aug_loop_vars, new_state,
|
_verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state,
|
||||||
basic_symbol_names + composite_symbol_names, opts)
|
basic_symbol_names + composite_symbol_names, opts)
|
||||||
|
|
||||||
return new_state
|
return new_state
|
||||||
@ -853,7 +895,10 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
|||||||
# This enforces consistency across versions.
|
# This enforces consistency across versions.
|
||||||
opts['return_same_structure'] = True
|
opts['return_same_structure'] = True
|
||||||
|
|
||||||
aug_init_vars = init_vars + get_state()
|
if 'shape_invariants' in opts:
|
||||||
|
opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
|
||||||
|
opts['shape_invariants'], aug_init_vars)
|
||||||
|
|
||||||
final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
|
final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
|
||||||
aug_init_vars, **opts)
|
aug_init_vars, **opts)
|
||||||
final_state = final_aug_vars[state_slice]
|
final_state = final_aug_vars[state_slice]
|
||||||
|
@ -503,13 +503,16 @@ def _shape_invariant_to_type_spec(var, shape):
|
|||||||
Returns:
|
Returns:
|
||||||
A `TypeSpec` for `var`, consistent with the given shape.
|
A `TypeSpec` for `var`, consistent with the given shape.
|
||||||
"""
|
"""
|
||||||
if isinstance(shape, type_spec.TypeSpec):
|
if shape is None:
|
||||||
|
return type_spec.type_spec_from_value(var)
|
||||||
|
elif isinstance(shape, type_spec.TypeSpec):
|
||||||
if not shape.is_compatible_with(var):
|
if not shape.is_compatible_with(var):
|
||||||
raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
|
raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
|
||||||
return shape
|
return shape
|
||||||
elif not isinstance(shape, tensor_shape.TensorShape):
|
elif not isinstance(shape, tensor_shape.TensorShape):
|
||||||
raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r"
|
raise TypeError(
|
||||||
% shape)
|
"Expected shape to be a TypeSpec, TensorShape or None, got %r for"
|
||||||
|
" value %r" % (shape, var))
|
||||||
|
|
||||||
if isinstance(var, ops.Tensor):
|
if isinstance(var, ops.Tensor):
|
||||||
return tensor_spec.TensorSpec(shape, var.dtype)
|
return tensor_spec.TensorSpec(shape, var.dtype)
|
||||||
|
@ -10,6 +10,6 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "set_loop_options"
|
name: "set_loop_options"
|
||||||
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,6 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "set_loop_options"
|
name: "set_loop_options"
|
||||||
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user