From 85ca0e9b815fc8438f1c10379c18393783021e15 Mon Sep 17 00:00:00 2001 From: "Joshua V. Dillon" Date: Wed, 2 Jan 2019 14:31:31 -0800 Subject: [PATCH] Remove tf.contrib.timeseries dependency on TF distributions. PiperOrigin-RevId: 227582617 --- .../timeseries/python/timeseries/BUILD | 3 +- .../timeseries/python/timeseries/ar_model.py | 19 ++++--------- .../python/timeseries/math_utils.py | 28 +++++++++++++++++++ .../timeseries/state_space_models/BUILD | 2 -- .../filtering_postprocessor.py | 8 ++---- .../state_space_models/kalman_filter.py | 9 +++--- 6 files changed, 43 insertions(+), 26 deletions(-) diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 4b90b596b28..a0c3204d41d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -361,9 +361,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":math_utils", ":model", ":model_utils", - "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index bcadf4094e1..a8d5e1a49dd 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -18,9 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -462,8 +461,8 @@ class ARModel(model.TimeSeriesModel): if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS: covariance = prediction_ops["covariance"] sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=sigma) - loss_op = -math_ops.reduce_sum(normal.log_prob(prediction)) + loss_op = -math_ops.reduce_sum( + math_utils.normal_log_prob(targets, sigma, prediction)) else: assert self.loss == ARModel.SQUARED_LOSS, self.loss loss_op = math_ops.reduce_sum(math_ops.square(prediction - targets)) @@ -965,16 +964,11 @@ class AnomalyMixtureARModel(ARModel): anomaly_variance = prediction_ops["anomaly_params"] anomaly_sigma = math_ops.sqrt( gen_math_ops.maximum(anomaly_variance, 1e-5)) - normal = distributions.Normal(loc=targets, scale=anomaly_sigma) - log_prob = normal.log_prob(prediction) + log_prob = math_utils.normal_log_prob(targets, anomaly_sigma, prediction) else: assert self._anomaly_distribution == AnomalyMixtureARModel.CAUCHY_ANOMALY anomaly_scale = prediction_ops["anomaly_params"] - cauchy = distributions.StudentT( - df=array_ops.ones([], dtype=anomaly_scale.dtype), - loc=targets, - scale=anomaly_scale) - log_prob = cauchy.log_prob(prediction) + log_prob = math_utils.cauchy_log_prob(targets, anomaly_scale, prediction) return log_prob def loss_op(self, targets, prediction_ops): @@ -983,8 +977,7 @@ class AnomalyMixtureARModel(ARModel): covariance = prediction_ops["covariance"] # Normal data log probability. sigma = math_ops.sqrt(gen_math_ops.maximum(covariance, 1e-5)) - normal1 = distributions.Normal(loc=targets, scale=sigma) - log_prob1 = normal1.log_prob(prediction) + log_prob1 = math_utils.normal_log_prob(targets, sigma, prediction) log_prob1 += math_ops.log(1 - self._anomaly_prior_probability) # Anomaly log probability. log_prob2 = self._anomaly_log_prob(targets, prediction_ops) diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index aab33064386..b7375e5055e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -21,6 +21,8 @@ from __future__ import print_function import collections import math +import numpy as np + from tensorflow.contrib import lookup from tensorflow.contrib.layers.python.layers import layers @@ -43,6 +45,32 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +def normal_log_prob(loc, scale, x): + """Computes the Normal log pdf.""" + z = (x - loc) / scale + return -0.5 * (math_ops.square(z) + + np.log(2. * np.pi) + math_ops.log(scale)) + + +def cauchy_log_prob(loc, scale, x): + """Computes the Cauchy log pdf.""" + z = (x - loc) / scale + return (-np.log(np.pi) - math_ops.log(scale) - + math_ops.log1p(math_ops.square(z))) + + +def mvn_tril_log_prob(loc, scale_tril, x): + """Computes the MVN log pdf under tril scale. Doesn't handle batches.""" + x0 = x - loc + z = linalg_ops.matrix_triangular_solve( + scale_tril, x0[..., array_ops.newaxis])[..., 0] + log_det_cov = 2. * math_ops.reduce_sum(math_ops.log( + array_ops.matrix_diag_part(scale_tril)), axis=-1) + d = math_ops.cast(array_ops.shape(scale_tril)[-1], log_det_cov.dtype) + return -0.5 * (math_ops.reduce_sum(math_ops.square(z), axis=-1) + + d * np.log(2. * np.pi) + log_det_cov) + + def clip_covariance( covariance_matrix, maximum_variance_ratio, minimum_variance): """Enforce constraints on a covariance matrix to improve numerical stability. diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD index 125750e7639..cf5e749042a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/BUILD @@ -78,7 +78,6 @@ py_library( srcs = ["kalman_filter.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -235,7 +234,6 @@ py_library( srcs = ["filtering_postprocessor.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/timeseries/python/timeseries:math_utils", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py index e9e2ac0aaf4..3fa2fbd9f77 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor.py @@ -22,8 +22,6 @@ import abc import six -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -91,10 +89,10 @@ def cauchy_alternative_to_gaussian(current_times, current_values, outputs): """ del current_times # unused cauchy_scale = math_utils.entropy_matched_cauchy_scale(outputs["covariance"]) - individual_log_pdfs = distributions.StudentT( - df=array_ops.ones([], dtype=current_values.dtype), + individual_log_pdfs = math_utils.cauchy_log_prob( loc=outputs["mean"], - scale=cauchy_scale).log_prob(current_values) + scale=cauchy_scale, + x=current_values) return math_ops.reduce_sum(individual_log_pdfs, axis=1) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py index a614386121e..c0ec797bc5b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import distributions - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.python.framework import dtypes @@ -137,9 +135,10 @@ class KalmanFilter(object): with ops.control_dependencies([non_negative_assert]): observation_covariance_cholesky = linalg_ops.cholesky( symmetrized_observation_covariance) - log_prediction_prob = distributions.MultivariateNormalTriL( - predicted_observation, observation_covariance_cholesky).log_prob( - observation) + log_prediction_prob = math_utils.mvn_tril_log_prob( + loc=predicted_observation, + scale_tril=observation_covariance_cholesky, + x=observation) (posterior_state, posterior_state_var) = self.posterior_from_prior_state( prior_state=estimated_state,