Replacing nest.is_sequence with nest.is_nested
PiperOrigin-RevId: 322206369 Change-Id: I21ea9f922b0c3c874aa50bd58b2eb1b1fba313f9
This commit is contained in:
parent
67f914e794
commit
800da0d15a
tensorflow/python/keras
@ -4105,9 +4105,9 @@ def rnn(step_function,
|
||||
# That's what the tile call does, it just repeats the mask along its
|
||||
# second dimension n times.
|
||||
def _expand_mask(mask_t, input_t, fixed_dim=1):
|
||||
if nest.is_sequence(mask_t):
|
||||
if nest.is_nested(mask_t):
|
||||
raise ValueError('mask_t is expected to be tensor, but got %s' % mask_t)
|
||||
if nest.is_sequence(input_t):
|
||||
if nest.is_nested(input_t):
|
||||
raise ValueError('input_t is expected to be tensor, but got %s' % input_t)
|
||||
rank_diff = len(input_t.shape) - len(mask_t.shape)
|
||||
for _ in range(rank_diff):
|
||||
@ -4133,7 +4133,7 @@ def rnn(step_function,
|
||||
input_t.reverse()
|
||||
return input_t
|
||||
|
||||
if nest.is_sequence(inputs):
|
||||
if nest.is_nested(inputs):
|
||||
processed_input = nest.map_structure(_process_single_input_t, inputs)
|
||||
else:
|
||||
processed_input = (_process_single_input_t(inputs),)
|
||||
|
@ -62,7 +62,7 @@ class Container(object):
|
||||
struct = map_to_output_names(outputs, self._output_names, struct)
|
||||
struct = map_missing_dict_keys(outputs, struct)
|
||||
# Allow passing one object that applies to all outputs.
|
||||
if not nest.is_sequence(struct) and nest.is_sequence(outputs):
|
||||
if not nest.is_nested(struct) and nest.is_nested(outputs):
|
||||
struct = nest.map_structure(lambda _: struct, outputs)
|
||||
return struct
|
||||
|
||||
@ -267,7 +267,7 @@ class LossesContainer(Container):
|
||||
return loss
|
||||
|
||||
def _should_broadcast(self, obj):
|
||||
return not nest.is_sequence(obj)
|
||||
return not nest.is_nested(obj)
|
||||
|
||||
def _copy_object(self, obj):
|
||||
return obj # Losses don't need to be copied.
|
||||
@ -478,11 +478,11 @@ class MetricsContainer(Container):
|
||||
|
||||
def _should_broadcast(self, obj):
|
||||
# e.g. 'mse'.
|
||||
if not nest.is_sequence(obj):
|
||||
if not nest.is_nested(obj):
|
||||
return True
|
||||
# e.g. ['mse'] or ['mse', 'mae'].
|
||||
return (isinstance(obj, (list, tuple)) and
|
||||
not any(nest.is_sequence(o) for o in obj))
|
||||
not any(nest.is_nested(o) for o in obj))
|
||||
|
||||
def _copy_object(self, obj):
|
||||
if isinstance(obj, metrics_mod.Metric):
|
||||
@ -572,10 +572,10 @@ def map_to_output_names(y_pred, output_names, struct):
|
||||
Returns:
|
||||
`struct` mapped to a list in same order as `output_names`.
|
||||
"""
|
||||
single_output = not nest.is_sequence(y_pred)
|
||||
single_output = not nest.is_nested(y_pred)
|
||||
outputs_are_flat_list = (not single_output and
|
||||
isinstance(y_pred, (list, tuple)) and
|
||||
not any(nest.is_sequence(y_p) for y_p in y_pred))
|
||||
not any(nest.is_nested(y_p) for y_p in y_pred))
|
||||
|
||||
if (single_output or outputs_are_flat_list) and isinstance(struct, dict):
|
||||
output_names = output_names or create_pseudo_output_names(y_pred)
|
||||
|
@ -1300,7 +1300,7 @@ def _make_class_weight_map_fn(class_weight):
|
||||
"""Convert `class_weight` to `sample_weight`."""
|
||||
x, y, sw = unpack_x_y_sample_weight(data)
|
||||
|
||||
if nest.is_sequence(y):
|
||||
if nest.is_nested(y):
|
||||
raise ValueError(
|
||||
"`class_weight` is only supported for Models with a single output.")
|
||||
|
||||
@ -1496,7 +1496,7 @@ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
|
||||
# there is no ambiguity. This also makes NumPy and Dataset
|
||||
# consistent in that the user does not have to wrap their Dataset
|
||||
# data in an unecessary tuple
|
||||
if not nest.is_sequence(x):
|
||||
if not nest.is_nested(x):
|
||||
return x
|
||||
else:
|
||||
return (x,)
|
||||
|
@ -130,9 +130,9 @@ class Functional(training_lib.Model):
|
||||
# be called with a dict, where the keys of the dict are the names
|
||||
# of the `Input` objects. Extra keys are ignored with warning.
|
||||
self._enable_dict_to_input_mapping = (
|
||||
not nest.is_sequence(self._nested_inputs) or
|
||||
not nest.is_nested(self._nested_inputs) or
|
||||
(isinstance(self._nested_inputs, (list, tuple, dict)) and
|
||||
not any(nest.is_sequence(t) for t in self._nested_inputs)))
|
||||
not any(nest.is_nested(t) for t in self._nested_inputs)))
|
||||
|
||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||
@ -519,7 +519,7 @@ class Functional(training_lib.Model):
|
||||
"""Maps `tensors` to their respective `keras.Input`."""
|
||||
if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
|
||||
ref_inputs = self._nested_inputs
|
||||
if not nest.is_sequence(ref_inputs):
|
||||
if not nest.is_nested(ref_inputs):
|
||||
ref_inputs = [self._nested_inputs]
|
||||
if isinstance(ref_inputs, dict):
|
||||
# In the case that the graph is constructed with dict input tensors,
|
||||
@ -1289,7 +1289,7 @@ def get_network_config(network, serialize_layer_fn=None):
|
||||
tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
|
||||
model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
|
||||
# Preserve external Keras compat for Models with single input.
|
||||
if not nest.is_sequence(model_inputs):
|
||||
if not nest.is_nested(model_inputs):
|
||||
model_inputs = [model_inputs]
|
||||
model_inputs = tf_utils.convert_inner_node_data(model_inputs)
|
||||
config['input_layers'] = model_inputs
|
||||
@ -1305,7 +1305,7 @@ def get_network_config(network, serialize_layer_fn=None):
|
||||
tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
|
||||
model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
|
||||
# Preserve external Keras compat for Models with single output.
|
||||
if not nest.is_sequence(model_outputs):
|
||||
if not nest.is_nested(model_outputs):
|
||||
model_outputs = [model_outputs]
|
||||
model_outputs = tf_utils.convert_inner_node_data(model_outputs)
|
||||
config['output_layers'] = model_outputs
|
||||
|
@ -198,7 +198,7 @@ class Node(object):
|
||||
return tf_utils.ListWrapper(data)
|
||||
|
||||
data = nest.map_structure(serialize_first_arg_tensor, inputs)
|
||||
if not nest.is_sequence(data):
|
||||
if not nest.is_nested(data):
|
||||
data = [data]
|
||||
data = tf_utils.convert_inner_node_data(data)
|
||||
return data
|
||||
|
@ -1235,7 +1235,7 @@ class MultiRNNCell(RNNCell):
|
||||
super(MultiRNNCell, self).__init__()
|
||||
if not cells:
|
||||
raise ValueError("Must specify at least one cell for MultiRNNCell.")
|
||||
if not nest.is_sequence(cells):
|
||||
if not nest.is_nested(cells):
|
||||
raise TypeError("cells must be a list or tuple, but saw: %s." % cells)
|
||||
|
||||
if len(set(id(cell) for cell in cells)) < len(cells):
|
||||
@ -1252,7 +1252,7 @@ class MultiRNNCell(RNNCell):
|
||||
self._track_trackable(cell, name="cell-%d" % (cell_number,))
|
||||
self._state_is_tuple = state_is_tuple
|
||||
if not state_is_tuple:
|
||||
if any(nest.is_sequence(c.state_size) for c in self._cells):
|
||||
if any(nest.is_nested(c.state_size) for c in self._cells):
|
||||
raise ValueError("Some cells return tuples of states, but the flag "
|
||||
"state_is_tuple is not set. State sizes are: %s" %
|
||||
str([c.state_size for c in self._cells]))
|
||||
@ -1309,7 +1309,7 @@ class MultiRNNCell(RNNCell):
|
||||
for i, cell in enumerate(self._cells):
|
||||
with vs.variable_scope("cell_%d" % i):
|
||||
if self._state_is_tuple:
|
||||
if not nest.is_sequence(state):
|
||||
if not nest.is_nested(state):
|
||||
raise ValueError(
|
||||
"Expected state to be a tuple of length %d, but received: %s" %
|
||||
(len(self.state_size), state))
|
||||
|
@ -139,7 +139,7 @@ class StackedRNNCells(Layer):
|
||||
# Call the cells in order and store the returned states.
|
||||
new_nested_states = []
|
||||
for cell, states in zip(self.cells, nested_states):
|
||||
states = states if nest.is_sequence(states) else [states]
|
||||
states = states if nest.is_nested(states) else [states]
|
||||
# TF cell does not wrap the state into list when there is only one state.
|
||||
is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
|
||||
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
|
||||
@ -448,7 +448,7 @@ class RNN(Layer):
|
||||
def states(self):
|
||||
if self._states is None:
|
||||
state = nest.map_structure(lambda _: None, self.cell.state_size)
|
||||
return state if nest.is_sequence(self.cell.state_size) else [state]
|
||||
return state if nest.is_nested(self.cell.state_size) else [state]
|
||||
return self._states
|
||||
|
||||
@states.setter
|
||||
@ -559,7 +559,7 @@ class RNN(Layer):
|
||||
# A nested tensor input
|
||||
pass
|
||||
|
||||
if not nest.is_sequence(input_shape):
|
||||
if not nest.is_nested(input_shape):
|
||||
# This indicates the there is only one input.
|
||||
if self.input_spec is not None:
|
||||
self.input_spec[0] = get_input_spec(input_shape)
|
||||
@ -632,7 +632,7 @@ class RNN(Layer):
|
||||
def get_initial_state(self, inputs):
|
||||
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
|
||||
|
||||
if nest.is_sequence(inputs):
|
||||
if nest.is_nested(inputs):
|
||||
# The input are nested sequences. Use the first element in the seq to get
|
||||
# batch size and dtype.
|
||||
inputs = nest.flatten(inputs)[0]
|
||||
@ -647,7 +647,7 @@ class RNN(Layer):
|
||||
init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
|
||||
dtype)
|
||||
# Keras RNN expect the states in a list, even if it's a single state tensor.
|
||||
if not nest.is_sequence(init_state):
|
||||
if not nest.is_nested(init_state):
|
||||
init_state = [init_state]
|
||||
# Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
|
||||
return list(init_state)
|
||||
@ -743,7 +743,7 @@ class RNN(Layer):
|
||||
# TODO(scottzhu): Should we accept multiple different masks?
|
||||
mask = nest.flatten(mask)[0]
|
||||
|
||||
if nest.is_sequence(inputs):
|
||||
if nest.is_nested(inputs):
|
||||
# In the case of nested input, use the first element for shape check.
|
||||
input_shape = K.int_shape(nest.flatten(inputs)[0])
|
||||
else:
|
||||
@ -782,7 +782,7 @@ class RNN(Layer):
|
||||
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
|
||||
output, new_states = cell_call_fn(
|
||||
inputs, states, constants=constants, **kwargs)
|
||||
if not nest.is_sequence(new_states):
|
||||
if not nest.is_nested(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
else:
|
||||
@ -790,7 +790,7 @@ class RNN(Layer):
|
||||
def step(inputs, states):
|
||||
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
|
||||
output, new_states = cell_call_fn(inputs, states, **kwargs)
|
||||
if not nest.is_sequence(new_states):
|
||||
if not nest.is_nested(new_states):
|
||||
new_states = [new_states]
|
||||
return output, new_states
|
||||
last_output, outputs, states = K.rnn(
|
||||
@ -929,7 +929,7 @@ class RNN(Layer):
|
||||
return K.zeros([batch_size] + tensor_shape.TensorShape(state).as_list())
|
||||
self.states = nest.map_structure(
|
||||
create_state_variable, self.cell.state_size)
|
||||
if not nest.is_sequence(self.states):
|
||||
if not nest.is_nested(self.states):
|
||||
self.states = [self.states]
|
||||
elif states is None:
|
||||
for state, size in zip(nest.flatten(self.states),
|
||||
@ -1359,7 +1359,7 @@ class SimpleRNNCell(DropoutRNNCellMixin, Layer):
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, states, training=None):
|
||||
prev_output = states[0] if nest.is_sequence(states) else states
|
||||
prev_output = states[0] if nest.is_nested(states) else states
|
||||
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
|
||||
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
|
||||
prev_output, training)
|
||||
@ -1377,7 +1377,7 @@ class SimpleRNNCell(DropoutRNNCellMixin, Layer):
|
||||
if self.activation is not None:
|
||||
output = self.activation(output)
|
||||
|
||||
new_state = [output] if nest.is_sequence(states) else output
|
||||
new_state = [output] if nest.is_nested(states) else output
|
||||
return output, new_state
|
||||
|
||||
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
|
||||
@ -1819,7 +1819,7 @@ class GRUCell(DropoutRNNCellMixin, Layer):
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs, states, training=None):
|
||||
h_tm1 = states[0] if nest.is_sequence(states) else states # previous memory
|
||||
h_tm1 = states[0] if nest.is_nested(states) else states # previous memory
|
||||
|
||||
dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
|
||||
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
|
||||
@ -1917,7 +1917,7 @@ class GRUCell(DropoutRNNCellMixin, Layer):
|
||||
hh = self.activation(x_h + recurrent_h)
|
||||
# previous and candidate state mixed by update gate
|
||||
h = z * h_tm1 + (1 - z) * hh
|
||||
new_state = [h] if nest.is_sequence(states) else h
|
||||
new_state = [h] if nest.is_nested(states) else h
|
||||
return h, new_state
|
||||
|
||||
def get_config(self):
|
||||
@ -3020,7 +3020,7 @@ def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
|
||||
init_state_size = [batch_size_tensor] + flat_dims
|
||||
return array_ops.zeros(init_state_size, dtype=dtype)
|
||||
|
||||
if nest.is_sequence(state_size):
|
||||
if nest.is_nested(state_size):
|
||||
return nest.map_structure(create_zeros, state_size)
|
||||
else:
|
||||
return create_zeros(state_size)
|
||||
|
@ -174,7 +174,7 @@ def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
|
||||
return map_fn(nested)
|
||||
|
||||
# Recursively convert.
|
||||
if not nest.is_sequence(nested):
|
||||
if not nest.is_nested(nested):
|
||||
raise ValueError(
|
||||
'Received non-atomic and non-sequence element: {}'.format(nested))
|
||||
if nest._is_mapping(nested):
|
||||
@ -284,7 +284,7 @@ def convert_inner_node_data(nested, wrap=False):
|
||||
return True
|
||||
if _is_serialized_node_data(nested):
|
||||
return True
|
||||
return not nest.is_sequence(nested)
|
||||
return not nest.is_nested(nested)
|
||||
|
||||
def _convert_object_or_list(nested):
|
||||
"""Convert b/t `ListWrapper` object and list representations."""
|
||||
|
Loading…
Reference in New Issue
Block a user