diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 74fd40f2c91..34fa0129dd8 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -571,9 +571,8 @@ class WALSModel(object):
       extras = size % num_shards
       assignments = tf.maximum(ids // (ids_per_shard + 1),
                                (ids - extras) // ids_per_shard)
-      new_ids = tf.select(assignments < extras,
-                          ids % (ids_per_shard + 1),
-                          (ids - extras) % ids_per_shard)
+      new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
+                         (ids - extras) % ids_per_shard)
       return assignments, new_ids
 
     return func
diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py
index 599a778ee57..e88d10c5823 100644
--- a/tensorflow/contrib/integrate/__init__.py
+++ b/tensorflow/contrib/integrate/__init__.py
@@ -22,10 +22,6 @@ We can use `odeint` to solve the
 differential equations, a prototypical example of chaotic dynamics:
 
 ```python
-import numpy as np
-import matplotlib.pyplot as plt
-import tensorflow as tf
-
 rho = 28.0
 sigma = 10.0
 beta = 8.0/3.0
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index b5b1dbb6355..cfe2fb15985 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
 from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
-from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
-from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
+from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
+from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
 from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 145bdcf2ee8..c9d1377ce73 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
 from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
 from tensorflow.contrib.learn.python.learn import trainable
 from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
+from tensorflow.contrib.learn.python.learn.estimators import metric_key
 from tensorflow.contrib.learn.python.learn.estimators import run_config
 from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
 from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
@@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):
     result = _make_metrics_ops(all_metrics, features, labels,
                                model_fn_ops.predictions)
-    if 'loss' not in result:
-      result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
+    if metric_key.MetricKey.LOSS not in result:
+      result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
+          model_fn_ops.loss)
     return result
 
   def _get_predict_ops(self, features):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index f3a4f409032..b38e671dc49 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -24,6 +24,8 @@ from tensorflow.contrib import losses
 from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.contrib.learn.python.learn import metric_spec
 from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators import metric_key
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
 from tensorflow.contrib.session_bundle import exporter
 from tensorflow.python import summary
 from tensorflow.python.framework import ops
@@ -388,17 +390,17 @@ class _RegressionHead(_Head):
   def _logits_to_prediction(self, logits=None):
     predictions = {}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.SCORES] = array_ops.squeeze(
+      predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
           logits, squeeze_dims=[1])
     else:
-      predictions[PredictionKey.SCORES] = logits
+      predictions[prediction_key.PredictionKey.SCORES] = logits
     return predictions
 
   # pylint: disable=undefined-variable
   def _create_signature_fn(self):
     def _regression_signature_fn(examples, unused_features, predictions):
       if isinstance(predictions, dict):
-        score = predictions[PredictionKey.SCORES]
+        score = predictions[prediction_key.PredictionKey.SCORES]
       else:
         score = predictions
 
@@ -409,11 +411,12 @@ class _RegressionHead(_Head):
     return _regression_signature_fn
 
   def _default_metric(self):
-    return {_head_prefixed(self._head_name, MetricKey.LOSS):
-            _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                               PredictionKey.SCORES,
-                                               self._label_name,
-                                               self._weight_column_name)}
+    return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+            _weighted_average_loss_metric_spec(
+                self._eval_loss_fn,
+                prediction_key.PredictionKey.SCORES,
+                self._label_name,
+                self._weight_column_name)}
 
 
 class _MultiClassHead(_Head):
@@ -530,12 +533,16 @@ class _MultiClassHead(_Head):
     return self._logits_to_prediction(logits)
 
   def _logits_to_prediction(self, logits=None):
-    predictions = {PredictionKey.LOGITS: logits}
+    # pylint: disable=missing-docstring
+    predictions = {prediction_key.PredictionKey.LOGITS: logits}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
+      predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
+          logits)
       logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
-    predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
+    predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
+        logits)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
+        logits, 1)
     return predictions
 
@@ -546,8 +553,9 @@ class _MultiClassHead(_Head):
       if isinstance(predictions, dict):
         default_signature = exporter.classification_signature(
             input_tensor=examples,
-            classes_tensor=predictions[PredictionKey.CLASSES],
-            scores_tensor=predictions[PredictionKey.PROBABILITIES])
+            classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
+            scores_tensor=predictions[
+                prediction_key.PredictionKey.PROBABILITIES])
       else:
         default_signature = exporter.classification_signature(
             input_tensor=examples,
@@ -558,44 +566,49 @@ class _MultiClassHead(_Head):
     return _classification_signature_fn
 
   def _default_metric(self):
-    metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
-               _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                                  PredictionKey.LOGITS,
-                                                  self._label_name,
-                                                  self._weight_column_name)}
+    metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+               _weighted_average_loss_metric_spec(
+                   self._eval_loss_fn,
+                   prediction_key.PredictionKey.LOGITS,
+                   self._label_name,
+                   self._weight_column_name)}
 
     # TODO(b/29366811): This currently results in both an "accuracy" and an
     # "accuracy/threshold_0.500000_mean" metric for binary classification.
-    metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
+    metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
         metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
-                               PredictionKey.CLASSES, self._label_name,
+                               prediction_key.PredictionKey.CLASSES,
+                               self._label_name,
                                self._weight_column_name))
     if self.logits_dimension == 1:
-      def _add_binary_metric(metric_key, metric_fn):
-        metrics[_head_prefixed(self._head_name, metric_key)] = (
+      def _add_binary_metric(key, metric_fn):
+        metrics[_head_prefixed(self._head_name, key)] = (
            metric_spec.MetricSpec(metric_fn,
-                                   PredictionKey.LOGISTIC,
+                                   prediction_key.PredictionKey.LOGISTIC,
                                    self._label_name,
                                    self._weight_column_name))
-      _add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
-      _add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
+      _add_binary_metric(
+          metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
+      _add_binary_metric(
+          metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
 
       # Also include the streaming mean of the label as an accuracy baseline, as
       # a reminder to users.
-      _add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
+      _add_binary_metric(
+          metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
 
-      _add_binary_metric(MetricKey.AUC, _streaming_auc)
+      _add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
 
       for threshold in self._thresholds:
-        _add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
                            _accuracy_at_threshold(threshold))
         # Precision for positive examples.
-        _add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
                            _streaming_at_threshold(
                                metrics_lib.streaming_precision_at_thresholds,
                                threshold),)
         # Recall for positive examples.
-        _add_binary_metric(MetricKey.RECALL_MEAN % threshold,
+        _add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
                            _streaming_at_threshold(
                                metrics_lib.streaming_recall_at_thresholds,
                                threshold))
@@ -635,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
 
   def _logits_to_prediction(self, logits=None):
     predictions = {}
-    predictions[PredictionKey.LOGITS] = logits
+    predictions[prediction_key.PredictionKey.LOGITS] = logits
     logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
+        logits, 1)
 
     return predictions
 
   def _default_metric(self):
-    metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
-               _weighted_average_loss_metric_spec(self._eval_loss_fn,
-                                                  PredictionKey.LOGITS,
-                                                  self._label_name,
-                                                  self._weight_column_name)}
-    metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
+    metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
+               _weighted_average_loss_metric_spec(
+                   self._eval_loss_fn,
+                   prediction_key.PredictionKey.LOGITS,
+                   self._label_name,
+                   self._weight_column_name)}
+    metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
         metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
-                               PredictionKey.CLASSES, self._label_name,
+                               prediction_key.PredictionKey.CLASSES,
+                               self._label_name,
                                self._weight_column_name))
     # TODO(sibyl-vie3Poto): add more metrics relevant for svms.
     return metrics
@@ -674,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
         thresholds=thresholds)
 
   def _logits_to_prediction(self, logits=None):
-    predictions = {PredictionKey.LOGITS: logits}
+    predictions = {prediction_key.PredictionKey.LOGITS: logits}
     if self.logits_dimension == 1:
-      predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
+      predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
+          logits)
       logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
-    predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
-    predictions[PredictionKey.CLASSES] = math_ops.to_int64(
+    predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
+        logits)
+    predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
         math_ops.greater(logits, 0))
     return predictions
 
@@ -857,15 +875,3 @@ class PredictionKey(object):
   LOGITS = "logits"
   LOGISTIC = "logistic"
   SCORES = "scores"
-
-
-class MetricKey(object):
-  LOSS = "loss"
-  AUC = "auc"
-  PREDICTION_MEAN = "labels/prediction_mean"
-  LABEL_MEAN = "labels/actual_label_mean"
-  ACCURACY = "accuracy"
-  ACCURACY_BASELINE = "accuracy/baseline_label_mean"
-  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
-  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
-  RECALL_MEAN = "recall/positive_threshold_%f_mean"
diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py
new file mode 100644
index 00000000000..8df08e507fe
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enum for metric keys."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class MetricKey(object):
+  LOSS = "loss"
+  AUC = "auc"
+  PREDICTION_MEAN = "labels/prediction_mean"
+  LABEL_MEAN = "labels/actual_label_mean"
+  ACCURACY = "accuracy"
+  ACCURACY_BASELINE = "accuracy/baseline_label_mean"
+  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
+  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
+  RECALL_MEAN = "recall/positive_threshold_%f_mean"
diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py
new file mode 100644
index 00000000000..a9c0c329584
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enum for model prediction keys."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class PredictionKey(object):
+  CLASSES = "classes"
+  PROBABILITIES = "probabilities"
+  LOGITS = "logits"
+  LOGISTIC = "logistic"
+  SCORES = "scores"
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 34017c22931..90b56b6a971 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -763,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
   computes the area under a discretized curve of precision versus recall values
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
-  closely approximating the true AUC.
+  closely approximating the true AUC. The quality of the approximation may vary
+  dramatically depending on `num_thresholds`.
+
+  For best results, `predictions` should be distributed approximately uniformly
+  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+  approximation may be poor if this is not the case.
 
   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
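The `num_thresholds` caveat above is easiest to see in use. A minimal sketch (the tensors here are hypothetical stand-ins for real model outputs):

```python
import tensorflow as tf

# Hypothetical batch of scores and boolean labels.
predictions = tf.constant([0.1, 0.35, 0.4, 0.8])
labels = tf.constant([False, True, False, True])

# More thresholds means a finer discretization of the curve and a closer
# approximation to the true AUC, at the cost of more accumulator variables.
auc, update_op = tf.contrib.metrics.streaming_auc(
    predictions, labels, num_thresholds=200)

with tf.Session() as sess:
  # The confusion-matrix accumulators are local variables.
  sess.run(tf.initialize_local_variables())
  sess.run(update_op)  # Accumulate statistics for this batch.
  print(sess.run(auc))
```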
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index 81dc8e9064f..9870c37d19e 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -15,9 +15,16 @@ py_library(
         "python/training/resample.py",
         "python/training/sampling_ops.py",
         "python/training/sequence_queueing_state_saver.py",
+        "python/training/training.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/python:framework",
+        "//tensorflow/python:ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:training",
+    ],
 )
 
 py_test(
@@ -93,6 +100,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "training_test",
+    size = "large",
+    srcs = ["python/training/training_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [
+        ":training_py",
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py
index d2a6368d785..721f8cdf750 100644
--- a/tensorflow/contrib/training/__init__.py
+++ b/tensorflow/contrib/training/__init__.py
@@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
 from tensorflow.contrib.training.python.training.resample import *
 from tensorflow.contrib.training.python.training.sampling_ops import *
 from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
+from tensorflow.contrib.training.python.training.training import add_gradients_summaries
+from tensorflow.contrib.training.python.training.training import clip_gradient_norms
+from tensorflow.contrib.training.python.training.training import create_train_op
+from tensorflow.contrib.training.python.training.training import multiply_gradients
+from tensorflow.contrib.training.python.training.training import train
 from tensorflow.python.util.all_util import make_all
 
 __all__ = make_all(__name__)
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py
new file mode 100644
index 00000000000..e65ef6ba119
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/training.py
@@ -0,0 +1,316 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains various routines and helper functions for training models.
+
+TODO(nsilberman): Port documentation.
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python import summary +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import monitored_session +from tensorflow.python.training import optimizer as tf_optimizer + +# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and +# multiply_gradients into contrib/summaries and contrib/optimizers.py +__all__ = [ + 'add_gradients_summaries', + 'clip_gradient_norms', + 'create_train_op', + 'multiply_gradients', + 'train', +] + + +def add_gradients_summaries(grads_and_vars): + """Add summaries to gradients. + + Args: + grads_and_vars: A list of gradient to variable pairs (tuples). + + Returns: + The list of created summaries. + """ + summaries = [] + for grad, var in grads_and_vars: + if grad is not None: + if isinstance(grad, ops.IndexedSlices): + grad_values = grad.values + else: + grad_values = grad + summaries.append(summary.histogram_summary( + var.op.name + ':gradient', grad_values)) + summaries.append(summary.histogram_summary( + var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values]))) + else: + logging.info('Var %s has no gradient', var.op.name) + + return summaries + + +def clip_gradient_norms(gradients_to_variables, max_norm): + """Clips the gradients by the given value. + + Args: + gradients_to_variables: A list of gradient to variable pairs (tuples). + max_norm: the maximum norm value. + + Returns: + A list of clipped gradient to variable pairs. + """ + clipped_grads_and_vars = [] + for grad, var in gradients_to_variables: + if grad is not None: + if isinstance(grad, ops.IndexedSlices): + tmp = clip_ops.clip_by_norm(grad.values, max_norm) + grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) + else: + grad = clip_ops.clip_by_norm(grad, max_norm) + clipped_grads_and_vars.append((grad, var)) + return clipped_grads_and_vars + + +def multiply_gradients(grads_and_vars, gradient_multipliers): + """Multiply specified gradients. + + Args: + grads_and_vars: A list of gradient to variable pairs (tuples). + gradient_multipliers: A map from either `Variables` or `Variable` op names + to the coefficient by which the associated gradient should be scaled. + + Returns: + The updated list of gradient to variable pairs. + + Raises: + ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers` + is empty or None or if `gradient_multipliers` is not a dictionary. 
+ """ + if not isinstance(grads_and_vars, list): + raise ValueError('`grads_and_vars` must be a list.') + if not gradient_multipliers: + raise ValueError('`gradient_multipliers` is empty.') + if not isinstance(gradient_multipliers, dict): + raise ValueError('`gradient_multipliers` must be a dict.') + + multiplied_grads_and_vars = [] + for grad, var in grads_and_vars: + if var in gradient_multipliers or var.op.name in gradient_multipliers: + key = var if var in gradient_multipliers else var.op.name + if grad is None: + raise ValueError('Requested multiple of `None` gradient.') + + if isinstance(grad, ops.IndexedSlices): + tmp = grad.values * constant_op.constant( + gradient_multipliers[key], dtype=grad.dtype) + grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) + else: + grad *= constant_op.constant( + gradient_multipliers[key], dtype=grad.dtype) + multiplied_grads_and_vars.append((grad, var)) + return multiplied_grads_and_vars + + +def create_train_op(total_loss, + optimizer, + global_step=None, + update_ops=None, + variables_to_train=None, + transform_grads_fn=None, + summarize_gradients=False, + gate_gradients=tf_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False): + """Creates an `Operation` that evaluates the gradients and returns the loss. + + Args: + total_loss: A `Tensor` representing the total loss. + optimizer: A tf.Optimizer to use for computing the gradients. + global_step: A `Tensor` representing the global step variable. If left as + `None`, then slim.variables.global_step() is used. + update_ops: An optional list of updates to execute. If `update_ops` is + `None`, then the update ops are set to the contents of the + `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but + it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, + a warning will be displayed. + variables_to_train: an optional list of variables to train. If None, it will + default to all tf.trainable_variables(). + transform_grads_fn: A function which takes a single argument, a list of + gradient to variable pairs (tuples), performs any requested gradient + updates, such as gradient clipping or multipliers, and returns the updated + list. + summarize_gradients: Whether or not add summaries for each gradient. + gate_gradients: How to gate the computation of gradients. See tf.Optimizer. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: Whether or not to try colocating the gradients + with the ops that generated them. + + Returns: + A `Tensor` that when evaluated, computes the gradients and returns the total + loss value. + """ + if global_step is None: + global_step = variables.get_or_create_global_step() + + # Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None. + global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + if update_ops is None: + update_ops = global_update_ops + else: + update_ops = set(update_ops) + if not global_update_ops.issubset(update_ops): + logging.warning('update_ops in create_train_op does not contain all the ' + ' update_ops in GraphKeys.UPDATE_OPS') + + # Make sure update_ops are computed before total_loss. 
+  if update_ops:
+    with ops.control_dependencies(update_ops):
+      barrier = control_flow_ops.no_op(name='update_barrier')
+    total_loss = control_flow_ops.with_dependencies([barrier], total_loss)
+
+  if variables_to_train is None:
+    # Default to tf.trainable_variables()
+    variables_to_train = tf_variables.trainable_variables()
+  else:
+    # Make sure that variables_to_train are in tf.trainable_variables()
+    for v in variables_to_train:
+      assert v in tf_variables.trainable_variables()
+
+  assert variables_to_train
+
+  # Create the gradients. Note that apply_gradients adds the gradient
+  # computation to the current graph.
+  grads = optimizer.compute_gradients(
+      total_loss,
+      variables_to_train,
+      gate_gradients=gate_gradients,
+      aggregation_method=aggregation_method,
+      colocate_gradients_with_ops=colocate_gradients_with_ops)
+
+  if transform_grads_fn:
+    grads = transform_grads_fn(grads)
+
+  # Summarize gradients.
+  if summarize_gradients:
+    with ops.name_scope('summarize_grads'):
+      add_gradients_summaries(grads)
+
+  # Create gradient updates.
+  grad_updates = optimizer.apply_gradients(grads, global_step=global_step)
+
+  with ops.name_scope('train_op'):
+    # Make sure total_loss is valid.
+    total_loss = array_ops.check_numerics(total_loss,
+                                          'LossTensor is inf or nan')
+
+    # Ensure the train_tensor computes grad_updates.
+    return control_flow_ops.with_dependencies([grad_updates], total_loss)
+
+
+def train(
+    train_op,
+    logdir,
+    master='',
+    is_chief=True,
+    scaffold=None,
+    hooks=None,
+    chief_only_hooks=None,
+    save_checkpoint_secs=600,
+    save_summaries_steps=100,
+    config=None):
+  """Runs the training loop.
+
+  Args:
+    train_op: A `Tensor` that, when executed, will apply the gradients and
+      return the loss value.
+    logdir: The directory where the graph and checkpoints are saved.
+    master: The URL of the master.
+    is_chief: Specifies whether or not the training is being run by the primary
+      replica during replica training.
+    scaffold: A `tf.train.Scaffold` instance.
+    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
+      training loop.
+    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
+      inside the training loop for the chief trainer only.
+    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
+      using a default checkpoint saver. If `save_checkpoint_secs` is set to
+      `None`, then the default checkpoint saver isn't used.
+    save_summaries_steps: The frequency, in number of global steps, that the
+      summaries are written to disk using a default summary saver. If
+      `save_summaries_steps` is set to `None`, then the default summary saver
+      isn't used.
+    config: An instance of `tf.ConfigProto`.
+
+  Returns:
+    The value of the loss function after training.
+
+  Raises:
+    ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
+      `save_summaries_steps` is not `None`.
+ """ + # TODO(nsilberman): move this logic into monitored_session.py + scaffold = scaffold or monitored_session.Scaffold() + + hooks = hooks or [] + + if is_chief: + session_creator = monitored_session.ChiefSessionCreator( + scaffold=scaffold, + checkpoint_dir=logdir, + master=master, + config=config) + + if chief_only_hooks: + hooks.extend(chief_only_hooks) + + hooks.append(basic_session_run_hooks.StepCounterHook( + output_dir=logdir)) + + if save_summaries_steps: + if logdir is None: + raise ValueError( + 'logdir cannot be None when save_summaries_steps is None') + hooks.append(basic_session_run_hooks.SummarySaverHook( + scaffold=scaffold, + save_steps=save_summaries_steps, + output_dir=logdir)) + + if save_checkpoint_secs: + if logdir is None: + raise ValueError( + 'logdir cannot be None when save_checkpoint_secs is None') + hooks.append(basic_session_run_hooks.CheckpointSaverHook( + logdir, save_secs=save_checkpoint_secs, scaffold=scaffold)) + else: + session_creator = monitored_session.WorkerSessionCreator( + scaffold=scaffold, master=master, config=config) + + with monitored_session.MonitoredSession( + session_creator=session_creator, hooks=hooks) as session: + loss = None + while not session.should_stop(): + loss = session.run(train_op) + return loss diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py new file mode 100644 index 00000000000..81de828a803 --- /dev/null +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -0,0 +1,514 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for tf.contrib.training.training."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+
+def logistic_classifier(inputs):
+  return tf.contrib.layers.fully_connected(
+      inputs, 1, activation_fn=tf.sigmoid)
+
+
+def batchnorm_classifier(inputs):
+  inputs = tf.contrib.layers.batch_norm(inputs, decay=0.1)
+  return tf.contrib.layers.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
+
+
+class CreateTrainOpTest(tf.test.TestCase):
+
+  def setUp(self):
+    np.random.seed(0)
+
+    # Create an easy training set:
+    self._inputs = np.random.rand(16, 4).astype(np.float32)
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+
+  def testUseUpdateOps(self):
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      expected_mean = np.mean(self._inputs, axis=(0))
+      expected_var = np.var(self._inputs, axis=(0))
+
+      tf_predictions = batchnorm_classifier(tf_inputs)
+      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+      total_loss = tf.contrib.losses.get_total_loss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
+
+      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables_by_name(
+          'moving_variance')[0]
+
+      with tf.Session() as sess:
+        # Initialize all variables
+        sess.run(tf.initialize_all_variables())
+        mean, variance = sess.run([moving_mean, moving_variance])
+        # After initialization moving_mean == 0 and moving_variance == 1.
+        self.assertAllClose(mean, [0] * 4)
+        self.assertAllClose(variance, [1] * 4)
+
+        for _ in range(10):
+          sess.run([train_op])
+        mean = moving_mean.eval()
+        variance = moving_variance.eval()
+        # After 10 updates with decay 0.1 moving_mean == expected_mean and
+        # moving_variance == expected_var.
+        self.assertAllClose(mean, expected_mean)
+        self.assertAllClose(variance, expected_var)
+
+  def testEmptyUpdateOps(self):
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      tf_predictions = batchnorm_classifier(tf_inputs)
+      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+      total_loss = tf.contrib.losses.get_total_loss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(
+          total_loss, optimizer, update_ops=[])
+
+      moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
+      moving_variance = tf.contrib.framework.get_variables_by_name(
+          'moving_variance')[0]
+
+      with tf.Session() as sess:
+        # Initialize all variables
+        sess.run(tf.initialize_all_variables())
+        mean, variance = sess.run([moving_mean, moving_variance])
+        # After initialization moving_mean == 0 and moving_variance == 1.
+        self.assertAllClose(mean, [0] * 4)
+        self.assertAllClose(variance, [1] * 4)
+
+        for _ in range(10):
+          sess.run([train_op])
+        mean = moving_mean.eval()
+        variance = moving_variance.eval()
+
+        # Since we skip update_ops the moving_vars are not updated.
+        self.assertAllClose(mean, [0] * 4)
+        self.assertAllClose(variance, [1] * 4)
+
+
+class TrainBNClassifierTest(tf.test.TestCase):
+
+  def setUp(self):
+    # Create an easy training set:
+    np.random.seed(0)
+
+    self._inputs = np.zeros((16, 4))
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+    self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')
+
+    for i in range(16):
+      j = int(2 * self._labels[i] + np.random.randint(0, 2))
+      self._inputs[i, j] = 1
+
+  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      tf_predictions = batchnorm_classifier(tf_inputs)
+      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+      total_loss = tf.contrib.losses.get_total_loss()
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(
+          total_loss, optimizer)
+
+      loss = tf.contrib.training.train(
+          train_op, self._logdir, hooks=[
+              tf.train.StopAtStepHook(num_steps=300)
+          ])
+      self.assertLess(loss, .1)
+
+
+class TrainTest(tf.test.TestCase):
+
+  def setUp(self):
+    # Create an easy training set:
+    np.random.seed(0)
+
+    self._inputs = np.zeros((16, 4))
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+
+    for i in range(16):
+      j = int(2 * self._labels[i] + np.random.randint(0, 2))
+      self._inputs[i, j] = 1
+
+  def testCanAchieveZeroLoss(self):
+    logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
+
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      tf_predictions = logistic_classifier(tf_inputs)
+      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+      total_loss = tf.contrib.losses.get_total_loss()
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
+
+      loss = tf.contrib.training.train(
+          train_op, logdir, hooks=[
+              tf.train.StopAtStepHook(num_steps=300)
+          ])
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .015)
+
+  def testTrainWithLocalVariable(self):
+    logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')
+
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      local_multiplier = tf.contrib.framework.local_variable(1.0)
+
+      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
+      tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+      total_loss = tf.contrib.losses.get_total_loss()
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(
+          total_loss, optimizer)
+
+      loss = tf.contrib.training.train(
+          train_op, logdir, hooks=[
+              tf.train.StopAtStepHook(num_steps=300)
+          ])
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .015)
+
+  def testResumeTrainAchievesRoughlyTheSameLoss(self):
+    number_of_steps = [300, 1, 5]
+    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')
+
+    for i in range(len(number_of_steps)):
+      with tf.Graph().as_default():
+        tf.set_random_seed(i)
+        tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+        tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+        tf_predictions = logistic_classifier(tf_inputs)
+        tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+        total_loss = tf.contrib.losses.get_total_loss()
+
+        optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+        train_op = tf.contrib.training.create_train_op(
+            total_loss, optimizer)
+
+        saver = tf.train.Saver()
+
+        loss = tf.contrib.training.train(
+            train_op, logdir, hooks=[
+                tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
+                tf.train.CheckpointSaverHook(
+                    logdir, save_steps=50, saver=saver),
+            ])
+        self.assertIsNotNone(loss)
+        self.assertLess(loss, .015)
+
+  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
+    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+    tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+    tf_predictions = logistic_classifier(tf_inputs)
+    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+    total_loss = tf.contrib.losses.get_total_loss()
+
+    optimizer = tf.train.GradientDescentOptimizer(
+        learning_rate=learning_rate)
+
+    def transform_grads_fn(grads):
+      if gradient_multiplier != 1.0:
+        variables = tf.trainable_variables()
+        gradient_multipliers = {var: gradient_multiplier for var in variables}
+
+        with tf.name_scope('multiply_grads'):
+          return tf.contrib.training.multiply_gradients(
+              grads, gradient_multipliers)
+      else:
+        return grads
+
+    return tf.contrib.training.create_train_op(
+        total_loss, optimizer, transform_grads_fn=transform_grads_fn)
+
+  def testTrainWithInitFromCheckpoint(self):
+    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
+    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
+
+    if tf.gfile.Exists(logdir1):  # For running on Jenkins.
+      tf.gfile.DeleteRecursively(logdir1)
+    if tf.gfile.Exists(logdir2):  # For running on Jenkins.
+      tf.gfile.DeleteRecursively(logdir2)
+
+    # First, train the model one step (make sure the error is high).
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      train_op = self.create_train_op()
+      saver = tf.train.Saver()
+      loss = tf.contrib.training.train(
+          train_op, logdir1, hooks=[
+              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
+              tf.train.StopAtStepHook(num_steps=1),
+          ], save_checkpoint_secs=None)
+      self.assertGreater(loss, .5)
+
+    # Next, train the model to convergence.
+    with tf.Graph().as_default():
+      tf.set_random_seed(1)
+      train_op = self.create_train_op()
+      saver = tf.train.Saver()
+      loss = tf.contrib.training.train(
+          train_op, logdir1, hooks=[
+              tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
+              tf.train.StopAtStepHook(num_steps=300),
+          ], save_checkpoint_secs=None)
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .02)
+
+    # Finally, advance the model a single step and validate that the loss is
+    # still low.
+    with tf.Graph().as_default():
+      tf.set_random_seed(2)
+      train_op = self.create_train_op()
+
+      model_variables = tf.all_variables()
+      model_path = os.path.join(logdir1, 'model.ckpt-300')
+
+      assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
+          model_path, model_variables)
+      def init_fn(_, session):
+        assign_fn(session)
+
+      loss = tf.contrib.training.train(
+          train_op,
+          logdir2,
+          scaffold=tf.train.Scaffold(init_fn=init_fn),
+          hooks=[tf.train.StopAtStepHook(num_steps=1)])
+
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .02)
+
+  def ModelLoss(self):
+    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+    tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+    tf_predictions = logistic_classifier(tf_inputs)
+    tf.contrib.losses.log_loss(tf_predictions, tf_labels)
+    return tf.contrib.losses.get_total_loss()
+
+  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
+    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
+    if tf.gfile.Exists(logdir):  # For running on Jenkins.
+      tf.gfile.DeleteRecursively(logdir)
+
+    # First, train only the weights of the model.
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      total_loss = self.ModelLoss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      weights = tf.contrib.framework.get_variables_by_name('weights')
+
+      train_op = tf.contrib.training.create_train_op(
+          total_loss,
+          optimizer,
+          variables_to_train=weights)
+
+      saver = tf.train.Saver()
+      loss = tf.contrib.training.train(
+          train_op, logdir, hooks=[
+              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
+              tf.train.StopAtStepHook(num_steps=200),
+          ])
+      self.assertGreater(loss, .015)
+      self.assertLess(loss, .05)
+
+    # Next, train the biases of the model.
+    with tf.Graph().as_default():
+      tf.set_random_seed(1)
+      total_loss = self.ModelLoss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      biases = tf.contrib.framework.get_variables_by_name('biases')
+
+      train_op = tf.contrib.training.create_train_op(
+          total_loss,
+          optimizer,
+          variables_to_train=biases)
+
+      saver = tf.train.Saver()
+      loss = tf.contrib.training.train(
+          train_op, logdir, hooks=[
+              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
+              tf.train.StopAtStepHook(num_steps=300),
+          ])
+      self.assertGreater(loss, .015)
+      self.assertLess(loss, .05)
+
+    # Finally, train both weights and bias to get lower loss.
+    with tf.Graph().as_default():
+      tf.set_random_seed(2)
+      total_loss = self.ModelLoss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
+      saver = tf.train.Saver()
+      loss = tf.contrib.training.train(
+          train_op, logdir, hooks=[
+              tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
+              tf.train.StopAtStepHook(num_steps=400),
+          ])
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .015)
+
+  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
+    # First, train only the weights of the model.
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      total_loss = self.ModelLoss()
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      weights, biases = tf.contrib.framework.get_variables()
+
+      train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
+      train_weights = tf.contrib.training.create_train_op(
+          total_loss, optimizer, variables_to_train=[weights])
+      train_biases = tf.contrib.training.create_train_op(
+          total_loss, optimizer, variables_to_train=[biases])
+
+      with tf.Session() as sess:
+        # Initialize the variables.
+        sess.run(tf.initialize_all_variables())
+
+        # Get the initial weights and biases values.
+        weights_values, biases_values = sess.run([weights, biases])
+        self.assertGreater(np.linalg.norm(weights_values), 0)
+        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)
+
+        # Update weights and biases.
+        loss = sess.run(train_op)
+        self.assertGreater(loss, .5)
+        new_weights, new_biases = sess.run([weights, biases])
+
+        # Check that the weights and biases have been updated.
+        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
+        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
+
+        weights_values, biases_values = new_weights, new_biases
+
+        # Update only weights.
+        loss = sess.run(train_weights)
+        self.assertGreater(loss, .5)
+        new_weights, new_biases = sess.run([weights, biases])
+
+        # Check that the weights have been updated, but biases have not.
+        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
+        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
+        weights_values = new_weights
+
+        # Update only biases.
+        loss = sess.run(train_biases)
+        self.assertGreater(loss, .5)
+        new_weights, new_biases = sess.run([weights, biases])
+
+        # Check that the biases have been updated, but weights have not.
+        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
+        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
+
+  def testTrainWithAlteredGradients(self):
+    # Use the same learning rate but different gradient multipliers
+    # to train two models. Model with equivalently larger learning
+    # rate (i.e., learning_rate * gradient_multiplier) has smaller
+    # training loss.
+    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
+    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')
+
+    if tf.gfile.Exists(logdir1):
+      tf.gfile.DeleteRecursively(logdir1)
+    if tf.gfile.Exists(logdir2):
+      tf.gfile.DeleteRecursively(logdir2)
+
+    multipliers = [1., 1000.]
+    number_of_steps = 10
+    losses = []
+    learning_rate = 0.001
+
+    # First, train the model with equivalently smaller learning rate.
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      train_op = self.create_train_op(
+          learning_rate=learning_rate,
+          gradient_multiplier=multipliers[0])
+
+      saver = tf.train.Saver()
+
+      loss = tf.contrib.training.train(
+          train_op, logdir1, hooks=[
+              tf.train.StopAtStepHook(num_steps=number_of_steps),
+              tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
+          ])
+
+      losses.append(loss)
+      self.assertGreater(loss, .5)
+
+    # Second, train the model with equivalently larger learning rate.
+    with tf.Graph().as_default():
+      tf.set_random_seed(0)
+      train_op = self.create_train_op(
+          learning_rate=learning_rate,
+          gradient_multiplier=multipliers[1])
+      saver = tf.train.Saver()
+
+      loss = tf.contrib.training.train(
+          train_op, logdir2, hooks=[
+              tf.train.StopAtStepHook(num_steps=number_of_steps),
+              tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
+          ])
+
+      losses.append(loss)
+      self.assertIsNotNone(loss)
+      self.assertLess(loss, .5)
+
+    # The loss of the model trained with larger learning rate should
+    # be smaller.
+    self.assertGreater(losses[0], losses[1])
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 79546ccd206..61db4ef42dd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -221,13 +221,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "jpeg",
-    hdrs = ["lib/jpeg/jpeg_mem.h"],
-    visibility = ["//visibility:public"],
-    deps = [":jpeg_internal"],
-)
-
 # Test support library needed for all tests
 # This is currently public, but may be made internal in the
 # future. Try to avoid depending on it.
@@ -699,9 +692,9 @@ filegroup(
         "platform/cuda.h",
         "platform/google/**/*",
         "platform/hadoop/**/*",
-        "platform/jpeg.*",
-        "platform/png.*",
-        "platform/gif.*",
+        "platform/gif.h",
+        "platform/jpeg.h",
+        "platform/png.h",
         "platform/stream_executor.*",
         "platform/windows/**/*",
        "user_ops/**/*.cu.cc",
@@ -981,7 +974,10 @@ cc_library(
         ],
         exclude = [
             "**/*test*",
+            "lib/gif/**/*",
            "lib/jpeg/**/*",
+            "platform/gif.h",
+            "platform/jpeg.h",
             "platform/**/cuda.h",
             "platform/**/stream_executor.h",
             "platform/load_library.cc",
@@ -998,7 +994,10 @@ cc_library(
         ],
         exclude = [
             "**/*test*",
+            "lib/gif/**/*",
             "lib/jpeg/**/*",
+            "platform/gif.h",
+            "platform/jpeg.h",
             "platform/**/cuda.h",
             "platform/**/stream_executor.h",
         ],
@@ -1016,7 +1015,6 @@ cc_library(
     hdrs = tf_additional_lib_hdrs() + [
         "lib/core/blocking_counter.h",
         "lib/core/refcount.h",
-        "lib/gif/gif_io.h",
         "lib/gtl/edit_distance.h",
         "lib/gtl/int_type.h",
         "lib/gtl/iterator_range.h",
@@ -1060,18 +1058,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "gif_internal",
+    srcs = [
+        "lib/gif/gif_io.cc",
+        "platform/gif.h",
+    ],
+    hdrs = ["lib/gif/gif_io.h"],
+    copts = tf_copts(),
+    linkopts = ["-ldl"],
+    deps = [
+        ":lib",
+        "//tensorflow/core/platform/default/build_config:gif",
+    ],
+)
+
 cc_library(
     name = "jpeg_internal",
-    srcs = glob(
-        [
-            "lib/jpeg/*h",
-            "lib/jpeg/*.cc",
-        ],
-        exclude = [
-            "**/*test*",
-        ],
-    ),
-    hdrs = ["lib/jpeg/jpeg_handle.h"],
+    srcs = [
+        "lib/jpeg/jpeg_handle.cc",
+        "lib/jpeg/jpeg_mem.cc",
+        "platform/jpeg.h",
+    ],
+    hdrs = [
+        "lib/jpeg/jpeg_handle.h",
+        "lib/jpeg/jpeg_mem.h",
+    ],
     copts = tf_copts(),
     linkopts = ["-ldl"],
     deps = [
@@ -1541,7 +1553,6 @@ cc_test(
     srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
    data = glob(["lib/jpeg/testdata/*.jpg"]),
     deps = [
-        ":jpeg",
         ":jpeg_internal",
         ":lib",
         ":lib_internal",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index b4f36034708..50240f4992e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1136,8 +1136,9 @@ tf_kernel_libraries(
         ":eigen_helpers",
         ":image_resizer_state",
         "//tensorflow/core:framework",
+        "//tensorflow/core:gif_internal",
         "//tensorflow/core:image_ops_op_lib",
-        "//tensorflow/core:jpeg",
+        "//tensorflow/core:jpeg_internal",
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index a63aa4d7a97..f4ff341a48a 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -76,16 +76,24 @@ cc_library(
     name = "platformlib",
     copts = tf_copts(),
     deps = [
+        ":gif",
+        ":jpeg",
         "//tensorflow/core:protos_cc",
         "@com_googlesource_code_re2//:re2",
         "@farmhash_archive//:farmhash",
-        "@gif_archive//:gif",
         "@highwayhash//:sip_hash",
-        "@jpeg_archive//:jpeg",
         "@png_archive//:png",
     ],
 )
 
+cc_library(
+    name = "gif",
+    copts = tf_copts(),
+    deps = [
+        "@gif_archive//:gif",
+    ],
+)
+
 cc_library(
     name = "jpeg",
     copts = tf_copts(),
diff --git a/tensorflow/g3doc/api_docs/python/contrib.integrate.md b/tensorflow/g3doc/api_docs/python/contrib.integrate.md
new file mode 100644
index 00000000000..dc2c16a0dac
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/contrib.integrate.md
@@ -0,0 +1,135 @@
+
+
+# Integrate (contrib)
+[TOC]
+
+Integration and ODE solvers for TensorFlow.
+
+## Example: Lorenz attractor
+
+We can use `odeint` to solve the
+[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
+differential equations, a prototypical example of chaotic dynamics:
+
+```python
+rho = 28.0
+sigma = 10.0
+beta = 8.0/3.0
+
+def lorenz_equation(state, t):
+  x, y, z = tf.unpack(state)
+  dx = sigma * (y - x)
+  dy = x * (rho - z) - y
+  dz = x * y - beta * z
+  return tf.pack([dx, dy, dz])
+
+init_state = tf.constant([0, 2, 20], dtype=tf.float64)
+t = np.linspace(0, 50, num=5000)
+tensor_state, tensor_info = tf.contrib.integrate.odeint(
+    lorenz_equation, init_state, t, full_output=True)
+
+sess = tf.Session()
+state, info = sess.run([tensor_state, tensor_info])
+x, y, z = state.T
+plt.plot(x, z)
+```
+
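Taken together, the new `tf.contrib.training` entry points introduced above compose as follows. A minimal sketch, assuming a hypothetical logistic model on random data (mirroring the tests) and an arbitrary example log directory:

```python
import tensorflow as tf

# Hypothetical data and model; any loss registered with tf.contrib.losses works.
inputs = tf.random_uniform((16, 4))
labels = tf.to_float(tf.random_uniform((16, 1)) > 0.5)
predictions = tf.contrib.layers.fully_connected(
    inputs, 1, activation_fn=tf.sigmoid)
tf.contrib.losses.log_loss(predictions, labels)
total_loss = tf.contrib.losses.get_total_loss()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

def transform_grads_fn(grads_and_vars):
  # Clip gradients before they are applied, per clip_gradient_norms above.
  return tf.contrib.training.clip_gradient_norms(grads_and_vars, max_norm=4.0)

# create_train_op wires in UPDATE_OPS, the global step, and gradient summaries.
train_op = tf.contrib.training.create_train_op(
    total_loss, optimizer,
    transform_grads_fn=transform_grads_fn,
    summarize_gradients=True)

# train() runs a MonitoredSession loop until a hook stops it and returns the
# final loss value.
loss = tf.contrib.training.train(
    train_op, '/tmp/train_logs',
    hooks=[tf.train.StopAtStepHook(num_steps=100)])
```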