TFTS: Move normalization to the base class, start using it for state space models
Previously, state space models adjusted their priors based on the data (e.g. setting initial variances to match the sample variance) but did not normalize the data itself. When the data has an extreme scale, this runs into precision issues. After this CL, state space models first normalize the data, then use adjusted statistics on top of that normalization to estimate initial observation/transition noise. Also fixes an issue where start-of-series statistics were incorrect for the first batch (which only shows up with large input scales).

PiperOrigin-RevId: 171044863
parent 266f771563
commit 558d878d91
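Before the diff, a sketch of the flow this CL sets up: TimeSeriesModel gains shared scaling helpers (see the TimeSeriesModel hunks below) that z-score each feature using the overall moments from the input statistics, express variances relative to the overall feature variance, and invert both transformations for predictions. A rough numpy stand-in, with all numbers hypothetical:

import numpy as np

# Overall feature moments, as the input statistics object would report them.
mean, variance = 1e8, 4e6
sigma = np.sqrt(variance)

raw_values = np.array([1e8 + 2e3, 1e8 - 1e3])
model_values = (raw_values - mean) / sigma    # _scale_data: input -> model scale
prior_variance = 9e5 / variance               # _scale_variance: used for priors

model_prediction = model_values.mean()        # filtering runs near unit scale
user_prediction = model_prediction * sigma + mean  # _scale_back_data: back out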
@@ -106,16 +106,6 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
          for state_element
          in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
 
-  def _transform(self, data):
-    """Normalize data based on input statistics to encourage stable training."""
-    mean, variance = self._input_statistics.overall_feature_moments
-    return (data - mean) / variance
-
-  def _de_transform(self, data):
-    """Transform data back to the input scale."""
-    mean, variance = self._input_statistics.overall_feature_moments
-    return data * variance + mean
-
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Update model state based on observations.
 
@@ -140,7 +130,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
     state_from_time, prediction, lstm_state = state
     with tf.control_dependencies(
         [tf.assert_equal(current_times, state_from_time)]):
-      transformed_values = self._transform(current_values)
+      # Subtract the mean and divide by the variance of the series. Slightly
+      # more efficient if done for a whole window (using the normalize_features
+      # argument to SequentialTimeSeriesModel).
+      transformed_values = self._scale_data(current_values)
       # Use mean squared error across features for the loss.
       predictions["loss"] = tf.reduce_mean(
           (prediction - transformed_values) ** 2, axis=-1)
@@ -156,7 +149,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         inputs=previous_observation_or_prediction, state=lstm_state)
     next_prediction = self._predict_from_lstm_output(lstm_output)
     new_state_tuple = (current_times, next_prediction, new_lstm_state)
-    return new_state_tuple, {"mean": self._de_transform(next_prediction)}
+    return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
 
   def _imputation_step(self, current_times, state):
     """Advance model state across a gap."""
@@ -89,8 +89,6 @@ class ARModel(model.TimeSeriesModel):
     self.hidden_layer_sizes = hidden_layer_sizes
     self.window_size = self.input_window_size + self.output_window_size
     self.loss = loss
-    self.stats_means = None
-    self.stats_sigmas = None
     super(ARModel, self).__init__(
         num_features=num_features)
     assert num_time_buckets > 0
@@ -106,32 +104,6 @@ class ARModel(model.TimeSeriesModel):
     assert len(self._periods) or self.input_window_size
     assert output_window_size > 0
 
-  def scale_data(self, data):
-    """Scale data according to stats."""
-    if self._input_statistics is not None:
-      return (data - self.stats_means) / self.stats_sigmas
-    else:
-      return data
-
-  def scale_back_data(self, data):
-    if self._input_statistics is not None:
-      return (data * self.stats_sigmas) + self.stats_means
-    else:
-      return data
-
-  def scale_back_variance(self, var):
-    if self._input_statistics is not None:
-      return var * self.stats_sigmas * self.stats_sigmas
-    else:
-      return var
-
-  def initialize_graph(self, input_statistics=None):
-    super(ARModel, self).initialize_graph(input_statistics=input_statistics)
-    if self._input_statistics:
-      self.stats_means, variances = (
-          self._input_statistics.overall_feature_moments)
-      self.stats_sigmas = math_ops.sqrt(variances)
-
   def get_start_state(self):
     # State which matches the format we'll return later. Typically this will not
     # be used by the model directly, but the shapes and dtypes should match so
@@ -388,8 +360,8 @@ class ARModel(model.TimeSeriesModel):
       predicted_covariance = array_ops.ones_like(predicted_mean)
 
     # Transform and scale the mean and covariance appropriately.
-    predicted_mean = self.scale_back_data(predicted_mean)
-    predicted_covariance = self.scale_back_variance(predicted_covariance)
+    predicted_mean = self._scale_back_data(predicted_mean)
+    predicted_covariance = self._scale_back_variance(predicted_covariance)
 
     return {"mean": predicted_mean,
             "covariance": predicted_covariance}
@@ -418,7 +390,7 @@ class ARModel(model.TimeSeriesModel):
           times_feature=TrainEvalFeatures.TIMES,
           window_size=self.window_size,
           times_shape=times.get_shape()))
-    values = self.scale_data(values)
+    values = self._scale_data(values)
     if self.input_window_size > 0:
       input_values = values[:, :self.input_window_size, :]
     else:
@@ -435,14 +407,14 @@ class ARModel(model.TimeSeriesModel):
       # (observed - predicted) ** 2.
       # Note that this affects only evaluation; the training loss is unaffected.
       loss = self.loss_op(
-          self.scale_back_data(targets),
-          {"mean": self.scale_back_data(prediction_ops["mean"])})
+          self._scale_back_data(targets),
+          {"mean": self._scale_back_data(prediction_ops["mean"])})
     else:
       loss = self.loss_op(targets, prediction_ops)
 
     # Scale back the prediction.
-    prediction = self.scale_back_data(prediction)
-    covariance = self.scale_back_variance(covariance)
+    prediction = self._scale_back_data(prediction)
+    covariance = self._scale_back_variance(covariance)
 
     return model.ModelOutputs(
         loss=loss,
@@ -565,7 +537,7 @@ class ARModel(model.TimeSeriesModel):
       new_state_times.set_shape((None, self.input_window_size))
       new_state_values = array_ops.concat(
           [previous_state_values,
-           self.scale_data(values)], axis=1)[:, -self.input_window_size:, :]
+           self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
       new_state_values.set_shape((None, self.input_window_size,
                                   self.num_features))
     else:
@@ -936,8 +936,7 @@ class InputStatisticsFromMiniBatch(object):
       start_time = variable_scope.get_variable(
           name="start_time",
           dtype=dtypes.int64,
-          initializer=init_ops.zeros_initializer(),
-          shape=[],
+          initializer=dtypes.int64.max,
           trainable=False)
       total_observation_count = variable_scope.get_variable(
           name="total_observation_count",
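The InputStatisticsFromMiniBatch hunk above is the start-of-series fix from the commit message: start_time is maintained as (effectively) a running minimum over batch timestamps, so initializing it to zero pins the estimated series start at zero whenever the real timestamps are large. Initializing it to the identity element for min (dtypes.int64.max) lets the first batch set it correctly. A toy illustration in plain Python, values hypothetical:

start_time = 2 ** 63 - 1  # identity element for min, like dtypes.int64.max
for batch_start_time in [1000, 990, 1001]:
    start_time = min(start_time, batch_start_time)
assert start_time == 990  # with an initial value of 0 it would have stayed 0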
@@ -80,6 +80,8 @@ class TimeSeriesModel(object):
     self.dtype = dtype
     self._input_statistics = None
     self._graph_initialized = False
+    self._stats_means = None
+    self._stats_sigmas = None
 
     # TODO(allenl): Move more of the generic machinery for generating and
     # predicting into TimeSeriesModel, and possibly share it between generate()
@@ -120,6 +122,38 @@ class TimeSeriesModel(object):
     """
     self._graph_initialized = True
     self._input_statistics = input_statistics
+    if self._input_statistics:
+      self._stats_means, variances = (
+          self._input_statistics.overall_feature_moments)
+      self._stats_sigmas = math_ops.sqrt(variances)
+
+  def _scale_data(self, data):
+    """Scale data according to stats (input scale -> model scale)."""
+    if self._input_statistics is not None:
+      return (data - self._stats_means) / self._stats_sigmas
+    else:
+      return data
+
+  def _scale_variance(self, variance):
+    """Scale variances according to stats (input scale -> model scale)."""
+    if self._input_statistics is not None:
+      return variance / self._input_statistics.overall_feature_moments.variance
+    else:
+      return variance
+
+  def _scale_back_data(self, data):
+    """Scale back data according to stats (model scale -> input scale)."""
+    if self._input_statistics is not None:
+      return (data * self._stats_sigmas) + self._stats_means
+    else:
+      return data
+
+  def _scale_back_variance(self, variance):
+    """Scale back variances according to stats (model scale -> input scale)."""
+    if self._input_statistics is not None:
+      return variance * self._input_statistics.overall_feature_moments.variance
+    else:
+      return variance
 
   def _check_graph_initialized(self):
     if not self._graph_initialized:
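The helper pairs above are exact inverses, which is what lets window-level code move freely between input scale and model scale. A quick check of the identities involved (numpy stand-in, values hypothetical):

import numpy as np

means, variances = np.array([10.0]), np.array([4.0])
sigmas = np.sqrt(variances)

data = np.array([12.0])
scaled = (data - means) / sigmas                         # _scale_data
assert np.allclose(scaled * sigmas + means, data)        # _scale_back_data

var = np.array([8.0])
assert np.allclose((var / variances) * variances, var)   # _scale_variance pair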
@@ -304,6 +338,7 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
                train_output_names,
                predict_output_names,
                num_features,
+               normalize_features=False,
                dtype=dtypes.float32,
                exogenous_feature_columns=None,
                exogenous_update_condition=None,
@@ -316,6 +351,12 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
       predict_output_names: A list of products/predictions returned from
           _prediction_step.
       num_features: Number of features for the time series
+      normalize_features: Boolean. If True, `values` are passed normalized to
+          the model (via self._scale_data). Scaling is done for the whole window
+          as a batch, which is slightly more efficient than scaling inside the
+          window loop. The model must then define _scale_back_predictions, which
+          may use _scale_back_data or _scale_back_variance to return predictions
+          to the input scale.
       dtype: The floating point datatype to use.
       exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
           objects. See `TimeSeriesModel`.
@@ -344,9 +385,25 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._exogenous_update_condition = exogenous_update_condition
     self._train_output_names = train_output_names
     self._predict_output_names = predict_output_names
+    self._normalize_features = normalize_features
     self._static_unrolling_window_size_threshold = (
         static_unrolling_window_size_threshold)
 
+  def _scale_back_predictions(self, predictions):
+    """Return a window of predictions to input scale.
+
+    Args:
+      predictions: A dictionary mapping from prediction names to Tensors.
+    Returns:
+      A dictionary with values corrected for input normalization (e.g. with
+      self._scale_back_mean and possibly self._scale_back_variance). May be a
+      mutated version of the argument.
+    """
+    raise NotImplementedError(
+        "SequentialTimeSeriesModel normalized input data"
+        " (normalize_features=True), but no method was provided to transform "
+        "the predictions back to the input scale.")
+
   @abc.abstractmethod
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Compute a single-step loss for a batch of data.
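Together with the new constructor flag, the contract reads: a subclass that passes normalize_features=True receives values already run through _scale_data and must override _scale_back_predictions to undo the normalization for each named output. A sketch of such a subclass (hypothetical, with the abstract step methods omitted, so illustrative rather than instantiable):

class _NormalizedModel(SequentialTimeSeriesModel):
  """Hypothetical subclass opting into window-level normalization."""

  def __init__(self, num_features):
    super(_NormalizedModel, self).__init__(
        train_output_names=["mean"],
        predict_output_names=["mean", "covariance"],
        num_features=num_features,
        normalize_features=True)  # values arrive already scaled

  def _scale_back_predictions(self, predictions):
    # Mirrors what StateSpaceModel does further down in this diff.
    predictions["mean"] = self._scale_back_data(predictions["mean"])
    predictions["covariance"] = self._scale_back_variance(
        predictions["covariance"])
    return predictions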
@@ -524,6 +581,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._check_graph_initialized()
     times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
     values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+    if self._normalize_features:
+      values = self._scale_data(values)
     exogenous_regressors = self._process_exogenous_features(
         times=times,
         features={key: value for key, value in features.items()
@@ -556,6 +615,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     # Since we have window-level additions to the loss, its per-step value is
     # misleading, so we avoid returning it.
     del outputs["loss"]
+    if self._normalize_features:
+      outputs = self._scale_back_predictions(outputs)
     return per_observation_loss, state, outputs
 
   def predict(self, features):
@@ -583,6 +644,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
         times=predict_times, state=start_state,
         state_update_fn=_call_prediction_step,
         outputs=self._predict_output_names)
+    if self._normalize_features:
+      predictions = self._scale_back_predictions(predictions)
     return predictions
 
 class _FakeTensorArray(object):
@@ -57,7 +57,9 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel):
         # TODO(allenl): Better support for multivariate series here.
         initial_value = array_ops.stack([
             math_ops.reduce_mean(
-                self._input_statistics.series_start_moments.mean), 0.
+                self._scale_data(
+                    self._input_statistics.series_start_moments.mean)),
+            0.
         ])
         return initial_value + variable_scope.get_variable(
             name="prior_state_mean",
@@ -232,6 +232,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
                             + filtering_postprocessor_names),
         predict_output_names=["mean", "covariance"],
         num_features=configuration.num_features,
+        normalize_features=True,
         dtype=configuration.dtype,
         exogenous_feature_columns=configuration.exogenous_feature_columns,
         exogenous_update_condition=configuration.exogenous_update_condition,
@@ -309,15 +310,10 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     _, _, priors_from_time = state
     times = ops.convert_to_tensor(times)
     priors_from_time = ops.convert_to_tensor(priors_from_time)
-    with ops.control_dependencies([
-        control_flow_ops.Assert(
-            math_ops.reduce_all(priors_from_time <= times[:, 0]),
-            [priors_from_time, times[:, 0]],
-            summarize=100)
-    ]):
-      times = array_ops.identity(times)
     intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
-    starting_gaps = times[:, 0] - priors_from_time
+    # Ignore negative starting gaps, since there will be transient start times
+    # as inputs statistics are computed.
+    starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
     # Pre-define transition matrices raised to powers (and their sums) for every
     # gap in this window. This avoids duplicate computation (for example many
     # steps will use the transition matrix raised to the first power) and
@@ -369,20 +365,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       Imputed model state corresponding to the `state` argument.
     """
     estimated_state, estimated_state_var, previous_times = state
-    catchup_times = current_times - previous_times
-    non_negative_assertion = control_flow_ops.Assert(
-        math_ops.reduce_all(catchup_times >= 0), [
-            "Negative imputation interval", catchup_times, current_times,
-            previous_times
-        ],
-        summarize=100)
-    with ops.control_dependencies([non_negative_assertion]):
-      transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
-          self._cached_transition_powers_and_sums(catchup_times))
-      estimated_state = self._kalman_filter.predict_state_mean(
-          estimated_state, transition_matrices)
-      estimated_state_var = self._kalman_filter.predict_state_var(
-          estimated_state_var, transition_matrices, transition_noise_sums)
+    # Ignore negative imputation intervals due to transient start time
+    # estimates.
+    catchup_times = math_ops.maximum(current_times - previous_times, 0)
+    transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
+        self._cached_transition_powers_and_sums(catchup_times))
+    estimated_state = self._kalman_filter.predict_state_mean(
+        estimated_state, transition_matrices)
+    estimated_state_var = self._kalman_filter.predict_state_var(
+        estimated_state_var, transition_matrices, transition_noise_sums)
     return (estimated_state, estimated_state_var,
             previous_times + catchup_times)
 
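Both StateSpaceModel hunks above trade assertions for clamping because the input statistics converge over the first few batches: the estimated series start (and hence priors_from_time / previous_times) can transiently sit ahead of the actual batch times. Clamping the gap at zero turns those transient steps into no-ops instead of hard failures:

current_times, previous_times = 5, 7  # transient: prior "ahead" of the batch
catchup_times = max(current_times - previous_times, 0)  # 0 rather than -2
# With a zero gap, the cached transition powers reduce to the identity and
# the imputation step leaves the state effectively unchanged.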
@@ -437,6 +428,13 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         outputs=predictions)
     return (filtered_state, predictions)
 
+  def _scale_back_predictions(self, predictions):
+    """Return a window of predictions to input scale."""
+    predictions["mean"] = self._scale_back_data(predictions["mean"])
+    predictions["covariance"] = self._scale_back_variance(
+        predictions["covariance"])
+    return predictions
+
   def _prediction_step(self, current_times, state):
     """Make a prediction based on `state`.
 
@@ -458,7 +456,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     """
     estimated_state, estimated_state_var, previous_times = state
     advanced_to_current_assert = control_flow_ops.Assert(
-        math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
+        math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
         ["Attempted to predict without imputation"])
     with ops.control_dependencies([advanced_to_current_assert]):
       observation_model = self.get_broadcasted_observation_model(current_times)
@@ -475,6 +473,9 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         (self.num_features,)))
     predicted_obs_var.set_shape(current_times.get_shape().concatenate(
         (self.num_features, self.num_features)))
+    # Not scaled back to input-scale, since this also feeds into the
+    # loss. Instead, predictions are scaled back before being returned to the
+    # user in _scale_back_predictions.
     predictions = {
         "mean": predicted_obs,
         "covariance": predicted_obs_var}
@@ -722,7 +723,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       # Make sure initial latent value uncertainty is at least on the same
       # scale as noise in the data.
       covariance_multiplier = math_ops.reduce_max(
-          self._input_statistics.series_start_moments.variance)
+          self._scale_variance(
+              self._input_statistics.series_start_moments.variance))
       return base_covariance * gen_math_ops.maximum(
           covariance_multiplier, 1.0)
     else:
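The remaining hunks apply the same correction everywhere a prior is derived from series-start statistics: since the filter now sees z-scored data, raw start-of-series variances must first be divided by the overall feature variance. With hypothetical numbers:

overall_variance = 100.0        # overall feature moments
series_start_variance = 25.0    # variance across the first few values
model_scale_variance = series_start_variance / overall_variance  # 0.25
# This is what _scale_variance returns and what the noise priors now consume.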
@@ -920,7 +922,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         self.get_noise_transform(), dtype=self.dtype)
     state_noise_dimension = state_noise_transform.get_shape()[1].value
     if self._input_statistics is not None:
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
       initial_transition_noise_scale = math_ops.log(
           gen_math_ops.maximum(
               math_ops.reduce_mean(feature_variance) / math_ops.cast(
@@ -945,7 +948,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     if self._input_statistics is not None:
       # Get variance across the first few values in each batch for each
       # feature, for an initial observation noise (over-)estimate.
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
     else:
       feature_variance = None
     if feature_variance is not None:
@@ -605,6 +605,7 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
     super(TimeDependentStateSpaceModel, self).__init__(
         configuration=state_space_model.StateSpaceModelConfiguration(
             use_observation_noise=False,
+            transition_covariance_initial_log_scale_bias=5.,
             static_unrolling_window_size_threshold=
             static_unrolling_window_size_threshold))
@@ -182,7 +182,8 @@ class VARMA(state_space_model.StateSpaceModel):
     # modeled as transition noise in VARMA, we set its initial value based on a
     # slight over-estimate empirical observation noise.
     if self._input_statistics is not None:
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
       initial_transition_noise_scale = math_ops.log(
           math_ops.maximum(
               math_ops.reduce_mean(feature_variance), minimum_initial_variance))