Update generated Python Op docs.

Change: 137988710

parent 2f8318865e
commit 7c89c44112
@@ -571,9 +571,8 @@ class WALSModel(object):
       extras = size % num_shards
       assignments = tf.maximum(ids // (ids_per_shard + 1),
                                (ids - extras) // ids_per_shard)
-      new_ids = tf.select(assignments < extras,
-                          ids % (ids_per_shard + 1),
-                          (ids - extras) % ids_per_shard)
+      new_ids = tf.where(assignments < extras, ids % (ids_per_shard + 1),
+                         (ids - extras) % ids_per_shard)
       return assignments, new_ids
     return func

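As an aside, here is a small NumPy sketch of the arithmetic this sharding helper performs (illustrative only: the `shard` function and the example sizes are made up for this note and are not part of the TensorFlow code):

```python
import numpy as np

def shard(ids, size, num_shards):
  """Mimics the WALS sharding arithmetic: the first `extras` shards hold one
  extra id each; returns (shard assignment, index within that shard)."""
  ids = np.asarray(ids)
  ids_per_shard = size // num_shards
  extras = size % num_shards
  assignments = np.maximum(ids // (ids_per_shard + 1),
                           (ids - extras) // ids_per_shard)
  # np.where plays the role of tf.select / tf.where in the hunk above.
  new_ids = np.where(assignments < extras,
                     ids % (ids_per_shard + 1),
                     (ids - extras) % ids_per_shard)
  return assignments, new_ids

# 10 ids over 3 shards -> shard sizes 4, 3, 3.
print(shard(np.arange(10), size=10, num_shards=3))
```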
@@ -22,10 +22,6 @@ We can use `odeint` to solve the
 differential equations, a prototypical example of chaotic dynamics:

 ```python
-import numpy as np
-import matplotlib.pyplot as plt
-import tensorflow as tf
-
 rho = 28.0
 sigma = 10.0
 beta = 8.0/3.0
@@ -29,11 +29,11 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
 from tensorflow.contrib.learn.python.learn.estimators.estimator import ModeKeys
-from tensorflow.contrib.learn.python.learn.estimators.head import MetricKey
-from tensorflow.contrib.learn.python.learn.estimators.head import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
+from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
+from tensorflow.contrib.learn.python.learn.estimators.prediction_key import PredictionKey
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator
 from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestLossHook
 from tensorflow.contrib.learn.python.learn.estimators.run_config import RunConfig
@ -45,6 +45,7 @@ from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
|
||||
from tensorflow.contrib.learn.python.learn import trainable
|
||||
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config
|
||||
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
||||
@@ -1108,8 +1109,9 @@ class Estimator(BaseEstimator):

     result = _make_metrics_ops(all_metrics, features, labels,
                                model_fn_ops.predictions)
-    if 'loss' not in result:
-      result['loss'] = metrics_lib.streaming_mean(model_fn_ops.loss)
+    if metric_key.MetricKey.LOSS not in result:
+      result[metric_key.MetricKey.LOSS] = metrics_lib.streaming_mean(
+          model_fn_ops.loss)
     return result

   def _get_predict_ops(self, features):
@ -24,6 +24,8 @@ from tensorflow.contrib import losses
|
||||
from tensorflow.contrib import metrics as metrics_lib
|
||||
from tensorflow.contrib.learn.python.learn import metric_spec
|
||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators import metric_key
|
||||
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
|
||||
from tensorflow.contrib.session_bundle import exporter
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import ops
|
||||
@ -388,17 +390,17 @@ class _RegressionHead(_Head):
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.SCORES] = array_ops.squeeze(
|
||||
predictions[prediction_key.PredictionKey.SCORES] = array_ops.squeeze(
|
||||
logits, squeeze_dims=[1])
|
||||
else:
|
||||
predictions[PredictionKey.SCORES] = logits
|
||||
predictions[prediction_key.PredictionKey.SCORES] = logits
|
||||
return predictions
|
||||
|
||||
# pylint: disable=undefined-variable
|
||||
def _create_signature_fn(self):
|
||||
def _regression_signature_fn(examples, unused_features, predictions):
|
||||
if isinstance(predictions, dict):
|
||||
score = predictions[PredictionKey.SCORES]
|
||||
score = predictions[prediction_key.PredictionKey.SCORES]
|
||||
else:
|
||||
score = predictions
|
||||
|
||||
@ -409,11 +411,12 @@ class _RegressionHead(_Head):
|
||||
return _regression_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
return {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
return {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.SCORES,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
|
||||
class _MultiClassHead(_Head):
|
||||
@ -530,12 +533,16 @@ class _MultiClassHead(_Head):
|
||||
return self._logits_to_prediction(logits)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
# pylint: disable=missing-docstring
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = nn.softmax(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
@ -546,8 +553,9 @@ class _MultiClassHead(_Head):
|
||||
if isinstance(predictions, dict):
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
classes_tensor=predictions[PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[PredictionKey.PROBABILITIES])
|
||||
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
|
||||
scores_tensor=predictions[
|
||||
prediction_key.PredictionKey.PROBABILITIES])
|
||||
else:
|
||||
default_signature = exporter.classification_signature(
|
||||
input_tensor=examples,
|
||||
@ -558,44 +566,49 @@ class _MultiClassHead(_Head):
|
||||
return _classification_signature_fn
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
|
||||
# TODO(b/29366811): This currently results in both an "accuracy" and an
|
||||
# "accuracy/threshold_0.500000_mean" metric for binary classification.
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
if self.logits_dimension == 1:
|
||||
def _add_binary_metric(metric_key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, metric_key)] = (
|
||||
def _add_binary_metric(key, metric_fn):
|
||||
metrics[_head_prefixed(self._head_name, key)] = (
|
||||
metric_spec.MetricSpec(metric_fn,
|
||||
PredictionKey.LOGISTIC,
|
||||
prediction_key.PredictionKey.LOGISTIC,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
_add_binary_metric(MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.PREDICTION_MEAN, _predictions_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.LABEL_MEAN, _labels_streaming_mean)
|
||||
|
||||
# Also include the streaming mean of the label as an accuracy baseline, as
|
||||
# a reminder to users.
|
||||
_add_binary_metric(MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
_add_binary_metric(
|
||||
metric_key.MetricKey.ACCURACY_BASELINE, _labels_streaming_mean)
|
||||
|
||||
_add_binary_metric(MetricKey.AUC, _streaming_auc)
|
||||
_add_binary_metric(metric_key.MetricKey.AUC, _streaming_auc)
|
||||
|
||||
for threshold in self._thresholds:
|
||||
_add_binary_metric(MetricKey.ACCURACY_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.ACCURACY_MEAN % threshold,
|
||||
_accuracy_at_threshold(threshold))
|
||||
# Precision for positive examples.
|
||||
_add_binary_metric(MetricKey.PRECISION_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.PRECISION_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_precision_at_thresholds,
|
||||
threshold),)
|
||||
# Recall for positive examples.
|
||||
_add_binary_metric(MetricKey.RECALL_MEAN % threshold,
|
||||
_add_binary_metric(metric_key.MetricKey.RECALL_MEAN % threshold,
|
||||
_streaming_at_threshold(
|
||||
metrics_lib.streaming_recall_at_thresholds,
|
||||
threshold))
|
||||
@ -635,21 +648,24 @@ class _BinarySvmHead(_MultiClassHead):
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {}
|
||||
predictions[PredictionKey.LOGITS] = logits
|
||||
predictions[prediction_key.PredictionKey.LOGITS] = logits
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.CLASSES] = math_ops.argmax(logits, 1)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
|
||||
logits, 1)
|
||||
|
||||
return predictions
|
||||
|
||||
def _default_metric(self):
|
||||
metrics = {_head_prefixed(self._head_name, MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(self._eval_loss_fn,
|
||||
PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, MetricKey.ACCURACY)] = (
|
||||
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
|
||||
_weighted_average_loss_metric_spec(
|
||||
self._eval_loss_fn,
|
||||
prediction_key.PredictionKey.LOGITS,
|
||||
self._label_name,
|
||||
self._weight_column_name)}
|
||||
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
|
||||
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
|
||||
PredictionKey.CLASSES, self._label_name,
|
||||
prediction_key.PredictionKey.CLASSES,
|
||||
self._label_name,
|
||||
self._weight_column_name))
|
||||
# TODO(sibyl-vie3Poto): add more metrics relevant for svms.
|
||||
return metrics
|
||||
@ -674,12 +690,14 @@ class _MultiLabelHead(_MultiClassHead):
|
||||
thresholds=thresholds)
|
||||
|
||||
def _logits_to_prediction(self, logits=None):
|
||||
predictions = {PredictionKey.LOGITS: logits}
|
||||
predictions = {prediction_key.PredictionKey.LOGITS: logits}
|
||||
if self.logits_dimension == 1:
|
||||
predictions[PredictionKey.LOGISTIC] = math_ops.sigmoid(logits)
|
||||
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
|
||||
logits)
|
||||
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
|
||||
predictions[PredictionKey.PROBABILITIES] = math_ops.sigmoid(logits)
|
||||
predictions[PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
predictions[prediction_key.PredictionKey.PROBABILITIES] = math_ops.sigmoid(
|
||||
logits)
|
||||
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.to_int64(
|
||||
math_ops.greater(logits, 0))
|
||||
return predictions
|
||||
|
||||
@ -857,15 +875,3 @@ class PredictionKey(object):
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
||||
|
@ -0,0 +1,30 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for metric keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class MetricKey(object):
|
||||
LOSS = "loss"
|
||||
AUC = "auc"
|
||||
PREDICTION_MEAN = "labels/prediction_mean"
|
||||
LABEL_MEAN = "labels/actual_label_mean"
|
||||
ACCURACY = "accuracy"
|
||||
ACCURACY_BASELINE = "accuracy/baseline_label_mean"
|
||||
ACCURACY_MEAN = "accuracy/threshold_%f_mean"
|
||||
PRECISION_MEAN = "precision/positive_threshold_%f_mean"
|
||||
RECALL_MEAN = "recall/positive_threshold_%f_mean"
|
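# Illustrative aside, not part of the file above: the "%f"-style keys are meant
# to be filled in with a threshold via string formatting, as head.py does with
# `MetricKey.ACCURACY_MEAN % threshold`, e.g.
#   MetricKey.ACCURACY_MEAN % 0.5  ->  "accuracy/threshold_0.500000_mean"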
@ -0,0 +1,26 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Enum for model prediction keys."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
class PredictionKey(object):
|
||||
CLASSES = "classes"
|
||||
PROBABILITIES = "probabilities"
|
||||
LOGITS = "logits"
|
||||
LOGISTIC = "logistic"
|
||||
SCORES = "scores"
|
@@ -763,7 +763,12 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200,
   computes the area under a discretized curve of precision versus recall values
   (computed using the aforementioned variables). The `num_thresholds` variable
   controls the degree of discretization with larger numbers of thresholds more
-  closely approximating the true AUC.
+  closely approximating the true AUC. The quality of the approximation may vary
+  dramatically depending on `num_thresholds`.
+
+  For best results, `predictions` should be distributed approximately uniformly
+  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
+  approximation may be poor if this is not the case.

   For estimation of the metric over a stream of data, the function creates an
   `update_op` operation that updates these variables and returns the `auc`.
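A rough usage sketch of the value-op/update-op pattern this docstring describes (assuming the function is exposed as `tf.contrib.metrics.streaming_auc`; the data is made up, and older releases spell the initializer `tf.initialize_local_variables()`):

```python
import tensorflow as tf

predictions = tf.placeholder(tf.float32, [None])
labels = tf.placeholder(tf.bool, [None])
auc, update_op = tf.contrib.metrics.streaming_auc(predictions, labels,
                                                  num_thresholds=200)

batches = [([0.9, 0.2, 0.8], [True, False, True]),
           ([0.1, 0.7, 0.4], [False, True, False])]

with tf.Session() as sess:
  # Streaming metrics keep their accumulator state in local variables.
  sess.run(tf.local_variables_initializer())
  for preds, labs in batches:
    sess.run(update_op, {predictions: preds, labels: labs})
  print(sess.run(auc))  # AUC accumulated over all batches seen so far
```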
@ -15,9 +15,16 @@ py_library(
|
||||
"python/training/resample.py",
|
||||
"python/training/sampling_ops.py",
|
||||
"python/training/sequence_queueing_state_saver.py",
|
||||
"python/training/training.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
@ -93,6 +100,19 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "training_test",
|
||||
size = "large",
|
||||
srcs = ["python/training/training_test.py"],
|
||||
shard_count = 3,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":training_py",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -70,6 +70,11 @@ from tensorflow.contrib.training.python.training.bucket_ops import *
|
||||
from tensorflow.contrib.training.python.training.resample import *
|
||||
from tensorflow.contrib.training.python.training.sampling_ops import *
|
||||
from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import *
|
||||
from tensorflow.contrib.training.python.training.training import add_gradients_summaries
|
||||
from tensorflow.contrib.training.python.training.training import clip_gradient_norms
|
||||
from tensorflow.contrib.training.python.training.training import create_train_op
|
||||
from tensorflow.contrib.training.python.training.training import multiply_gradients
|
||||
from tensorflow.contrib.training.python.training.training import train
|
||||
from tensorflow.python.util.all_util import make_all
|
||||
|
||||
__all__ = make_all(__name__)
|
||||
|
tensorflow/contrib/training/python/training/training.py (new file, 316 lines)
@@ -0,0 +1,316 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Contains various routines and helper functions for training models.
|
||||
|
||||
TODO(nsilberman): Port documentation.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.python import summary
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variables as tf_variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import optimizer as tf_optimizer
|
||||
|
||||
# TODO(nsilberman): move add_gradients_summaries, clip_gradient_norms and
|
||||
# multiply_gradients into contrib/summaries and contrib/optimizers.py
|
||||
__all__ = [
|
||||
'add_gradients_summaries',
|
||||
'clip_gradient_norms',
|
||||
'create_train_op',
|
||||
'multiply_gradients',
|
||||
'train',
|
||||
]
|
||||
|
||||
|
||||
def add_gradients_summaries(grads_and_vars):
|
||||
"""Add summaries to gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
|
||||
Returns:
|
||||
The list of created summaries.
|
||||
"""
|
||||
summaries = []
|
||||
for grad, var in grads_and_vars:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
grad_values = grad.values
|
||||
else:
|
||||
grad_values = grad
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient', grad_values))
|
||||
summaries.append(summary.histogram_summary(
|
||||
var.op.name + ':gradient_norm', clip_ops.global_norm([grad_values])))
|
||||
else:
|
||||
logging.info('Var %s has no gradient', var.op.name)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
def clip_gradient_norms(gradients_to_variables, max_norm):
|
||||
"""Clips the gradients by the given value.
|
||||
|
||||
Args:
|
||||
gradients_to_variables: A list of gradient to variable pairs (tuples).
|
||||
max_norm: the maximum norm value.
|
||||
|
||||
Returns:
|
||||
A list of clipped gradient to variable pairs.
|
||||
"""
|
||||
clipped_grads_and_vars = []
|
||||
for grad, var in gradients_to_variables:
|
||||
if grad is not None:
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = clip_ops.clip_by_norm(grad.values, max_norm)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad = clip_ops.clip_by_norm(grad, max_norm)
|
||||
clipped_grads_and_vars.append((grad, var))
|
||||
return clipped_grads_and_vars
|
||||
|
||||
|
||||
def multiply_gradients(grads_and_vars, gradient_multipliers):
|
||||
"""Multiply specified gradients.
|
||||
|
||||
Args:
|
||||
grads_and_vars: A list of gradient to variable pairs (tuples).
|
||||
gradient_multipliers: A map from either `Variables` or `Variable` op names
|
||||
to the coefficient by which the associated gradient should be scaled.
|
||||
|
||||
Returns:
|
||||
The updated list of gradient to variable pairs.
|
||||
|
||||
Raises:
|
||||
ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
|
||||
is empty or None or if `gradient_multipliers` is not a dictionary.
|
||||
"""
|
||||
if not isinstance(grads_and_vars, list):
|
||||
raise ValueError('`grads_and_vars` must be a list.')
|
||||
if not gradient_multipliers:
|
||||
raise ValueError('`gradient_multipliers` is empty.')
|
||||
if not isinstance(gradient_multipliers, dict):
|
||||
raise ValueError('`gradient_multipliers` must be a dict.')
|
||||
|
||||
multiplied_grads_and_vars = []
|
||||
for grad, var in grads_and_vars:
|
||||
if var in gradient_multipliers or var.op.name in gradient_multipliers:
|
||||
key = var if var in gradient_multipliers else var.op.name
|
||||
if grad is None:
|
||||
raise ValueError('Requested multiple of `None` gradient.')
|
||||
|
||||
if isinstance(grad, ops.IndexedSlices):
|
||||
tmp = grad.values * constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
|
||||
else:
|
||||
grad *= constant_op.constant(
|
||||
gradient_multipliers[key], dtype=grad.dtype)
|
||||
multiplied_grads_and_vars.append((grad, var))
|
||||
return multiplied_grads_and_vars
|
||||
|
||||
|
||||
def create_train_op(total_loss,
|
||||
optimizer,
|
||||
global_step=None,
|
||||
update_ops=None,
|
||||
variables_to_train=None,
|
||||
transform_grads_fn=None,
|
||||
summarize_gradients=False,
|
||||
gate_gradients=tf_optimizer.Optimizer.GATE_OP,
|
||||
aggregation_method=None,
|
||||
colocate_gradients_with_ops=False):
|
||||
"""Creates an `Operation` that evaluates the gradients and returns the loss.
|
||||
|
||||
Args:
|
||||
total_loss: A `Tensor` representing the total loss.
|
||||
optimizer: A tf.Optimizer to use for computing the gradients.
|
||||
global_step: A `Tensor` representing the global step variable. If left as
|
||||
`None`, then slim.variables.global_step() is used.
|
||||
update_ops: An optional list of updates to execute. If `update_ops` is
|
||||
`None`, then the update ops are set to the contents of the
|
||||
`tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
|
||||
it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
|
||||
a warning will be displayed.
|
||||
variables_to_train: an optional list of variables to train. If None, it will
|
||||
default to all tf.trainable_variables().
|
||||
transform_grads_fn: A function which takes a single argument, a list of
|
||||
gradient to variable pairs (tuples), performs any requested gradient
|
||||
updates, such as gradient clipping or multipliers, and returns the updated
|
||||
list.
|
||||
summarize_gradients: Whether or not add summaries for each gradient.
|
||||
gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
|
||||
aggregation_method: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
colocate_gradients_with_ops: Whether or not to try colocating the gradients
|
||||
with the ops that generated them.
|
||||
|
||||
Returns:
|
||||
A `Tensor` that when evaluated, computes the gradients and returns the total
|
||||
loss value.
|
||||
"""
|
||||
if global_step is None:
|
||||
global_step = variables.get_or_create_global_step()
|
||||
|
||||
# Update ops use GraphKeys.UPDATE_OPS collection if update_ops is None.
|
||||
global_update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||
if update_ops is None:
|
||||
update_ops = global_update_ops
|
||||
else:
|
||||
update_ops = set(update_ops)
|
||||
if not global_update_ops.issubset(update_ops):
|
||||
logging.warning('update_ops in create_train_op does not contain all the '
|
||||
' update_ops in GraphKeys.UPDATE_OPS')
|
||||
|
||||
# Make sure update_ops are computed before total_loss.
|
||||
if update_ops:
|
||||
with ops.control_dependencies(update_ops):
|
||||
barrier = control_flow_ops.no_op(name='update_barrier')
|
||||
total_loss = control_flow_ops.with_dependencies([barrier], total_loss)
|
||||
|
||||
if variables_to_train is None:
|
||||
# Default to tf.trainable_variables()
|
||||
variables_to_train = tf_variables.trainable_variables()
|
||||
else:
|
||||
# Make sure that variables_to_train are in tf.trainable_variables()
|
||||
for v in variables_to_train:
|
||||
assert v in tf_variables.trainable_variables()
|
||||
|
||||
assert variables_to_train
|
||||
|
||||
# Create the gradients. Note that apply_gradients adds the gradient
|
||||
# computation to the current graph.
|
||||
grads = optimizer.compute_gradients(
|
||||
total_loss,
|
||||
variables_to_train,
|
||||
gate_gradients=gate_gradients,
|
||||
aggregation_method=aggregation_method,
|
||||
colocate_gradients_with_ops=colocate_gradients_with_ops)
|
||||
|
||||
if transform_grads_fn:
|
||||
grads = transform_grads_fn(grads)
|
||||
|
||||
# Summarize gradients.
|
||||
if summarize_gradients:
|
||||
with ops.name_scope('summarize_grads'):
|
||||
add_gradients_summaries(grads)
|
||||
|
||||
# Create gradient updates.
|
||||
grad_updates = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
with ops.name_scope('train_op'):
|
||||
# Make sure total_loss is valid.
|
||||
total_loss = array_ops.check_numerics(total_loss,
|
||||
'LossTensor is inf or nan')
|
||||
|
||||
# Ensure the train_tensor computes grad_updates.
|
||||
return control_flow_ops.with_dependencies([grad_updates], total_loss)
|
||||
|
||||
|
||||
def train(
|
||||
train_op,
|
||||
logdir,
|
||||
master='',
|
||||
is_chief=True,
|
||||
scaffold=None,
|
||||
hooks=None,
|
||||
chief_only_hooks=None,
|
||||
save_checkpoint_secs=600,
|
||||
save_summaries_steps=100,
|
||||
config=None):
|
||||
"""Runs the training loop.
|
||||
|
||||
Args:
|
||||
train_op: A `Tensor` that, when executed, will apply the gradients and
|
||||
return the loss value.
|
||||
logdir: The directory where the graph and checkpoints are saved.
|
||||
master: The URL of the master.
|
||||
is_chief: Specifies whether or not the training is being run by the primary
|
||||
replica during replica training.
|
||||
scaffold: An tf.train.Scaffold instance.
|
||||
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
|
||||
training loop.
|
||||
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
|
||||
inside the training loop for the chief trainer only.
|
||||
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
|
||||
using a default checkpoint saver. If `save_checkpoint_secs` is set to
|
||||
`None`, then the default checkpoint saver isn't used.
|
||||
save_summaries_steps: The frequency, in number of global steps, that the
|
||||
summaries are written to disk using a default summary saver. If
|
||||
`save_summaries_steps` is set to `None`, then the default summary saver
|
||||
isn't used.
|
||||
config: An instance of `tf.ConfigProto`.
|
||||
|
||||
Returns:
|
||||
the value of the loss function after training.
|
||||
|
||||
Raises:
|
||||
ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
|
||||
`save_summaries_steps` is `None`.
|
||||
"""
|
||||
# TODO(nsilberman): move this logic into monitored_session.py
|
||||
scaffold = scaffold or monitored_session.Scaffold()
|
||||
|
||||
hooks = hooks or []
|
||||
|
||||
if is_chief:
|
||||
session_creator = monitored_session.ChiefSessionCreator(
|
||||
scaffold=scaffold,
|
||||
checkpoint_dir=logdir,
|
||||
master=master,
|
||||
config=config)
|
||||
|
||||
if chief_only_hooks:
|
||||
hooks.extend(chief_only_hooks)
|
||||
|
||||
hooks.append(basic_session_run_hooks.StepCounterHook(
|
||||
output_dir=logdir))
|
||||
|
||||
if save_summaries_steps:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_summaries_steps is None')
|
||||
hooks.append(basic_session_run_hooks.SummarySaverHook(
|
||||
scaffold=scaffold,
|
||||
save_steps=save_summaries_steps,
|
||||
output_dir=logdir))
|
||||
|
||||
if save_checkpoint_secs:
|
||||
if logdir is None:
|
||||
raise ValueError(
|
||||
'logdir cannot be None when save_checkpoint_secs is None')
|
||||
hooks.append(basic_session_run_hooks.CheckpointSaverHook(
|
||||
logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
|
||||
else:
|
||||
session_creator = monitored_session.WorkerSessionCreator(
|
||||
scaffold=scaffold, master=master, config=config)
|
||||
|
||||
with monitored_session.MonitoredSession(
|
||||
session_creator=session_creator, hooks=hooks) as session:
|
||||
loss = None
|
||||
while not session.should_stop():
|
||||
loss = session.run(train_op)
|
||||
return loss
|
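Pulling the pieces together, a minimal usage sketch of the two entry points above (the toy model, loss, and logdir are illustrative; the API names are the ones defined in this file and exercised by the test file below):

```python
import tensorflow as tf

inputs = tf.random_normal([16, 4])
labels = tf.random_uniform([16, 1])

# Any scalar total_loss works; this toy model mirrors the tests below.
predictions = tf.contrib.layers.fully_connected(inputs, 1,
                                                activation_fn=tf.sigmoid)
tf.contrib.losses.log_loss(predictions, labels)
total_loss = tf.contrib.losses.get_total_loss()

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)

# Runs a MonitoredSession until the hook fires; returns the final loss value.
loss = tf.contrib.training.train(
    train_op, logdir='/tmp/example_logs',
    hooks=[tf.train.StopAtStepHook(num_steps=100)])
```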
tensorflow/contrib/training/python/training/training_test.py (new file, 514 lines)
@@ -0,0 +1,514 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.contrib.training.training."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def logistic_classifier(inputs):
|
||||
return tf.contrib.layers.fully_connected(
|
||||
inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
def batchnorm_classifier(inputs):
|
||||
inputs = tf.contrib.layers.batch_norm(inputs, decay=0.1)
|
||||
return tf.contrib.layers.fully_connected(inputs, 1, activation_fn=tf.sigmoid)
|
||||
|
||||
|
||||
class CreateTrainOpTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
np.random.seed(0)
|
||||
|
||||
# Create an easy training set:
|
||||
self._inputs = np.random.rand(16, 4).astype(np.float32)
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
def testUseUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
expected_mean = np.mean(self._inputs, axis=(0))
|
||||
expected_var = np.var(self._inputs, axis=(0))
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
# After 10 updates with decay 0.1 moving_mean == expected_mean and
|
||||
# moving_variance == expected_var.
|
||||
self.assertAllClose(mean, expected_mean)
|
||||
self.assertAllClose(variance, expected_var)
|
||||
|
||||
def testEmptyUpdateOps(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, update_ops=[])
|
||||
|
||||
moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[0]
|
||||
moving_variance = tf.contrib.framework.get_variables_by_name(
|
||||
'moving_variance')[0]
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize all variables
|
||||
sess.run(tf.initialize_all_variables())
|
||||
mean, variance = sess.run([moving_mean, moving_variance])
|
||||
# After initialization moving_mean == 0 and moving_variance == 1.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
for _ in range(10):
|
||||
sess.run([train_op])
|
||||
mean = moving_mean.eval()
|
||||
variance = moving_variance.eval()
|
||||
|
||||
# Since we skip update_ops the moving_vars are not updated.
|
||||
self.assertAllClose(mean, [0] * 4)
|
||||
self.assertAllClose(variance, [1] * 4)
|
||||
|
||||
|
||||
class TrainBNClassifierTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
self._logdir = os.path.join(self.get_temp_dir(), 'tmp_bnlogs/')
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
|
||||
g = tf.Graph()
|
||||
with g.as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = batchnorm_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, self._logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertLess(loss, .1)
|
||||
|
||||
|
||||
class TrainTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Create an easy training set:
|
||||
np.random.seed(0)
|
||||
|
||||
self._inputs = np.zeros((16, 4))
|
||||
self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
|
||||
|
||||
for i in range(16):
|
||||
j = int(2 * self._labels[i] + np.random.randint(0, 2))
|
||||
self._inputs[i, j] = 1
|
||||
|
||||
def testCanAchieveZeroLoss(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'can_achieve_zero_loss')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainWithLocalVariable(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'train_with_local_variable')
|
||||
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
local_multiplier = tf.contrib.framework.local_variable(1.0)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=300)
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testResumeTrainAchievesRoughlyTheSameLoss(self):
|
||||
number_of_steps = [300, 1, 5]
|
||||
logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')
|
||||
|
||||
for i in range(len(number_of_steps)):
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(i)
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps[i]),
|
||||
tf.train.CheckpointSaverHook(
|
||||
logdir, save_steps=50, saver=saver),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
total_loss = tf.contrib.losses.get_total_loss()
|
||||
|
||||
optimizer = tf.train.GradientDescentOptimizer(
|
||||
learning_rate=learning_rate)
|
||||
|
||||
def transform_grads_fn(grads):
|
||||
if gradient_multiplier != 1.0:
|
||||
variables = tf.trainable_variables()
|
||||
gradient_multipliers = {var: gradient_multiplier for var in variables}
|
||||
|
||||
with tf.name_scope('multiply_grads'):
|
||||
return tf.contrib.training.multiply_gradients(
|
||||
grads, gradient_multipliers)
|
||||
else:
|
||||
return grads
|
||||
|
||||
return tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, transform_grads_fn=transform_grads_fn)
|
||||
|
||||
def testTrainWithInitFromCheckpoint(self):
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
|
||||
|
||||
if tf.gfile.Exists(logdir1): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
# First, train the model one step (make sure the error is high).
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=1),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Next, train the model to convergence.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
train_op = self.create_train_op()
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
], save_checkpoint_secs=None)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
# Finally, advance the model a single step and validate that the loss is
|
||||
# still low.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
train_op = self.create_train_op()
|
||||
|
||||
model_variables = tf.all_variables()
|
||||
model_path = os.path.join(logdir1, 'model.ckpt-300')
|
||||
|
||||
assign_fn = tf.contrib.framework.assign_from_checkpoint_fn(
|
||||
model_path, model_variables)
|
||||
def init_fn(_, session):
|
||||
assign_fn(session)
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op,
|
||||
logdir2,
|
||||
scaffold=tf.train.Scaffold(init_fn=init_fn),
|
||||
hooks=[tf.train.StopAtStepHook(num_steps=1)])
|
||||
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .02)
|
||||
|
||||
def ModelLoss(self):
|
||||
tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
|
||||
tf_labels = tf.constant(self._labels, dtype=tf.float32)
|
||||
|
||||
tf_predictions = logistic_classifier(tf_inputs)
|
||||
tf.contrib.losses.log_loss(tf_predictions, tf_labels)
|
||||
return tf.contrib.losses.get_total_loss()
|
||||
|
||||
def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
|
||||
logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
|
||||
if tf.gfile.Exists(logdir): # For running on jenkins.
|
||||
tf.gfile.DeleteRecursively(logdir)
|
||||
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights = tf.contrib.framework.get_variables_by_name('weights')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=weights)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=200),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Next, train the biases of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(1)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
biases = tf.contrib.framework.get_variables_by_name('biases')
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(
|
||||
total_loss,
|
||||
optimizer,
|
||||
variables_to_train=biases)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=300),
|
||||
])
|
||||
self.assertGreater(loss, .015)
|
||||
self.assertLess(loss, .05)
|
||||
|
||||
# Finally, train both weights and bias to get lower loss.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(2)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
saver = tf.train.Saver()
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir, hooks=[
|
||||
tf.train.CheckpointSaverHook(logdir, save_steps=1, saver=saver),
|
||||
tf.train.StopAtStepHook(num_steps=400),
|
||||
])
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .015)
|
||||
|
||||
def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
|
||||
# First, train only the weights of the model.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
total_loss = self.ModelLoss()
|
||||
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
|
||||
weights, biases = tf.contrib.framework.get_variables()
|
||||
|
||||
train_op = tf.contrib.training.create_train_op(total_loss, optimizer)
|
||||
train_weights = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[weights])
|
||||
train_biases = tf.contrib.training.create_train_op(
|
||||
total_loss, optimizer, variables_to_train=[biases])
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Initialize the variables.
|
||||
sess.run(tf.initialize_all_variables())
|
||||
|
||||
# Get the initial weights and biases values.
|
||||
weights_values, biases_values = sess.run([weights, biases])
|
||||
self.assertGreater(np.linalg.norm(weights_values), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values), 0)
|
||||
|
||||
# Update weights and biases.
|
||||
loss = sess.run(train_op)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights and biases have been updated.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
weights_values, biases_values = new_weights, new_biases
|
||||
|
||||
# Update only weights.
|
||||
loss = sess.run(train_weights)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the weights have been updated, but biases have not.
|
||||
self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
|
||||
weights_values = new_weights
|
||||
|
||||
# Update only biases.
|
||||
loss = sess.run(train_biases)
|
||||
self.assertGreater(loss, .5)
|
||||
new_weights, new_biases = sess.run([weights, biases])
|
||||
|
||||
# Check that the biases have been updated, but weights have not.
|
||||
self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
|
||||
self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)
|
||||
|
||||
def testTrainWithAlteredGradients(self):
|
||||
# Use the same learning rate but different gradient multipliers
|
||||
# to train two models. Model with equivalently larger learning
|
||||
# rate (i.e., learning_rate * gradient_multiplier) has smaller
|
||||
# training loss.
|
||||
logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
|
||||
logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')
|
||||
|
||||
if tf.gfile.Exists(logdir1):
|
||||
tf.gfile.DeleteRecursively(logdir1)
|
||||
if tf.gfile.Exists(logdir2):
|
||||
tf.gfile.DeleteRecursively(logdir2)
|
||||
|
||||
multipliers = [1., 1000.]
|
||||
number_of_steps = 10
|
||||
losses = []
|
||||
learning_rate = 0.001
|
||||
|
||||
# First, train the model with equivalently smaller learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[0])
|
||||
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir1, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir1, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertGreater(loss, .5)
|
||||
|
||||
# Second, train the model with equivalently larger learning rate.
|
||||
with tf.Graph().as_default():
|
||||
tf.set_random_seed(0)
|
||||
train_op = self.create_train_op(
|
||||
learning_rate=learning_rate,
|
||||
gradient_multiplier=multipliers[1])
|
||||
saver = tf.train.Saver()
|
||||
|
||||
loss = tf.contrib.training.train(
|
||||
train_op, logdir2, hooks=[
|
||||
tf.train.StopAtStepHook(num_steps=number_of_steps),
|
||||
tf.train.CheckpointSaverHook(logdir2, save_steps=50, saver=saver),
|
||||
])
|
||||
|
||||
losses.append(loss)
|
||||
self.assertIsNotNone(loss)
|
||||
self.assertLess(loss, .5)
|
||||
|
||||
# The loss of the model trained with larger learning rate should
|
||||
# be smaller.
|
||||
self.assertGreater(losses[0], losses[1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -221,13 +221,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg",
|
||||
hdrs = ["lib/jpeg/jpeg_mem.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":jpeg_internal"],
|
||||
)
|
||||
|
||||
# Test support library needed for all tests
|
||||
# This is currently public, but may be made internal in the
|
||||
# future. Try to avoid depending on it.
|
||||
@ -699,9 +692,9 @@ filegroup(
|
||||
"platform/cuda.h",
|
||||
"platform/google/**/*",
|
||||
"platform/hadoop/**/*",
|
||||
"platform/jpeg.*",
|
||||
"platform/png.*",
|
||||
"platform/gif.*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/png.h",
|
||||
"platform/stream_executor.*",
|
||||
"platform/windows/**/*",
|
||||
"user_ops/**/*.cu.cc",
|
||||
@ -981,7 +974,10 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/gif/**/*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
"platform/load_library.cc",
|
||||
@ -998,7 +994,10 @@ cc_library(
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
"lib/gif/**/*",
|
||||
"lib/jpeg/**/*",
|
||||
"platform/gif.h",
|
||||
"platform/jpeg.h",
|
||||
"platform/**/cuda.h",
|
||||
"platform/**/stream_executor.h",
|
||||
],
|
||||
@ -1016,7 +1015,6 @@ cc_library(
|
||||
hdrs = tf_additional_lib_hdrs() + [
|
||||
"lib/core/blocking_counter.h",
|
||||
"lib/core/refcount.h",
|
||||
"lib/gif/gif_io.h",
|
||||
"lib/gtl/edit_distance.h",
|
||||
"lib/gtl/int_type.h",
|
||||
"lib/gtl/iterator_range.h",
|
||||
@ -1060,18 +1058,32 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gif_internal",
|
||||
srcs = [
|
||||
"lib/gif/gif_io.cc",
|
||||
"platform/gif.h",
|
||||
],
|
||||
hdrs = ["lib/gif/gif_io.h"],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
":lib",
|
||||
"//tensorflow/core/platform/default/build_config:gif",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg_internal",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/jpeg/*h",
|
||||
"lib/jpeg/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
],
|
||||
),
|
||||
hdrs = ["lib/jpeg/jpeg_handle.h"],
|
||||
srcs = [
|
||||
"lib/jpeg/jpeg_handle.cc",
|
||||
"lib/jpeg/jpeg_mem.cc",
|
||||
"platform/jpeg.h",
|
||||
],
|
||||
hdrs = [
|
||||
"lib/jpeg/jpeg_handle.h",
|
||||
"lib/jpeg/jpeg_mem.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
linkopts = ["-ldl"],
|
||||
deps = [
|
||||
@ -1541,7 +1553,6 @@ cc_test(
|
||||
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
|
||||
data = glob(["lib/jpeg/testdata/*.jpg"]),
|
||||
deps = [
|
||||
":jpeg",
|
||||
":jpeg_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
|
@ -1136,8 +1136,9 @@ tf_kernel_libraries(
|
||||
":eigen_helpers",
|
||||
":image_resizer_state",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gif_internal",
|
||||
"//tensorflow/core:image_ops_op_lib",
|
||||
"//tensorflow/core:jpeg",
|
||||
"//tensorflow/core:jpeg_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -76,16 +76,24 @@ cc_library(
|
||||
name = "platformlib",
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
":gif",
|
||||
":jpeg",
|
||||
"//tensorflow/core:protos_cc",
|
||||
"@com_googlesource_code_re2//:re2",
|
||||
"@farmhash_archive//:farmhash",
|
||||
"@gif_archive//:gif",
|
||||
"@highwayhash//:sip_hash",
|
||||
"@jpeg_archive//:jpeg",
|
||||
"@png_archive//:png",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gif",
|
||||
copts = tf_copts(),
|
||||
deps = [
|
||||
"@gif_archive//:gif",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "jpeg",
|
||||
copts = tf_copts(),
|
||||
|
tensorflow/g3doc/api_docs/python/contrib.integrate.md (new file, 135 lines)
@@ -0,0 +1,135 @@
|
||||
<!-- This file is machine generated: DO NOT EDIT! -->
|
||||
|
||||
# Integrate (contrib)
|
||||
[TOC]
|
||||
|
||||
Integration and ODE solvers for TensorFlow.
|
||||
|
||||
## Example: Lorenz attractor
|
||||
|
||||
We can use `odeint` to solve the
|
||||
[Lorenz system](https://en.wikipedia.org/wiki/Lorenz_system) of ordinary
|
||||
differential equations, a prototypical example of chaotic dynamics:
|
||||
|
||||
```python
|
||||
rho = 28.0
|
||||
sigma = 10.0
|
||||
beta = 8.0/3.0
|
||||
|
||||
def lorenz_equation(state, t):
|
||||
x, y, z = tf.unpack(state)
|
||||
dx = sigma * (y - x)
|
||||
dy = x * (rho - z) - y
|
||||
dz = x * y - beta * z
|
||||
return tf.pack([dx, dy, dz])
|
||||
|
||||
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
|
||||
t = np.linspace(0, 50, num=5000)
|
||||
tensor_state, tensor_info = tf.contrib.integrate.odeint(
|
||||
lorenz_equation, init_state, t, full_output=True)
|
||||
|
||||
sess = tf.Session()
|
||||
state, info = sess.run([tensor_state, tensor_info])
|
||||
x, y, z = state.T
|
||||
plt.plot(x, z)
|
||||
```
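(The snippet assumes the usual imports are available as `np`, `plt`, and `tf`; the earlier hunk in this commit drops the explicit `import numpy as np`, `import matplotlib.pyplot as plt`, and `import tensorflow as tf` lines from the example itself.)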
|
||||

<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../../images/lorenz_attractor.png" alt>
</div>

## Ops

- - -

### `tf.contrib.integrate.odeint(func, y0, t, rtol=1e-06, atol=1e-12, method=None, options=None, full_output=False, name=None)` {#odeint}

Integrate a system of ordinary differential equations.

Solves the initial value problem for a non-stiff system of first order ODEs:

```
dy/dt = func(y, t), y(t[0]) = y0
```

where y is a Tensor of any shape.

For example:

```
# solve `dy/dt = -y`, corresponding to exponential decay
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
=> [1, exp(-1), exp(-2)]
```

Output dtypes and numerical precision are based on the dtypes of the inputs
`y0` and `t`.

Currently, implements 5th order Runge-Kutta with adaptive step size control
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
method of `scipy.integrate.ode` and MATLAB's `ode45`.

Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
doi:10.2307/2008219

##### Args:


* <b>`func`</b>: Function that maps a Tensor holding the state `y` and a scalar Tensor
    `t` into a Tensor of state derivatives with respect to time.
* <b>`y0`</b>: N-D Tensor giving starting value of `y` at time point `t[0]`. May
    have any floating point or complex dtype.
* <b>`t`</b>: 1-D Tensor holding a sequence of time points for which to solve for
    `y`. The initial time point should be the first element of this sequence,
    and each time must be larger than the previous time. May have any floating
    point dtype. If not provided as a Tensor, converted to a Tensor with
    float64 dtype.
* <b>`rtol`</b>: optional float64 Tensor specifying an upper bound on relative error,
    per element of `y`.
* <b>`atol`</b>: optional float64 Tensor specifying an upper bound on absolute error,
    per element of `y`.
* <b>`method`</b>: optional string indicating the integration method to use. Currently,
    the only valid option is `'dopri5'`.
* <b>`options`</b>: optional dict of configuring options for the indicated integration
    method. Can only be provided if a `method` is explicitly set. For
    `'dopri5'`, valid options include:
    * first_step: an initial guess for the size of the first integration step
      (current default: 1.0, but may later be changed to use heuristics based
      on the gradient).
    * safety: safety factor for adaptive step control, generally a constant
      in the range 0.8-1 (default: 0.9).
    * ifactor: maximum factor by which the adaptive step may be increased
      (default: 10.0).
    * dfactor: maximum factor by which the adaptive step may be decreased
      (default: 0.2).
    * max_num_steps: integer maximum number of integration steps between time
      points in `t` (default: 1000).
* <b>`full_output`</b>: optional boolean. If True, `odeint` returns a tuple
    `(y, info_dict)` describing the integration process.
* <b>`name`</b>: Optional name for this operation.

##### Returns:


* <b>`y`</b>: (N+1)-D tensor, where the first dimension corresponds to different
    time points. Contains the solved value of y for each desired time point in
    `t`, with the initial value `y0` being the first element along the first
    dimension.
* <b>`info_dict`</b>: only if `full_output == True`. A dict with the following values:
    * num_func_evals: integer Tensor counting the number of function
      evaluations.
    * integrate_points: 1D float64 Tensor with the upper bound of each
      integration time step.
    * error_ratio: 1D float Tensor with the estimated ratio of the integration
      error to the error tolerance at each integration step. A ratio greater
      than 1 corresponds to rejected steps.

##### Raises:


* <b>`ValueError`</b>: if an invalid `method` is provided.
* <b>`TypeError`</b>: if `options` is supplied without `method`, or if `t` or `y0` has
    an invalid dtype.

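A slightly fuller sketch of calling `odeint` with explicit tolerances, solver options, and `full_output` diagnostics; the values below are illustrative, not defaults taken from the implementation:

```python
import numpy as np
import tensorflow as tf

y0 = tf.constant(1.0, dtype=tf.float64)
t = np.linspace(0.0, 2.0, num=11)

# Solve dy/dt = -y with tighter tolerances; full_output=True also returns
# diagnostics about the adaptive Dormand-Prince stepping.
y, info = tf.contrib.integrate.odeint(
    lambda y, t: -y, y0, t,
    rtol=1e-8, atol=1e-10,
    method='dopri5', options={'max_num_steps': 5000},
    full_output=True)

with tf.Session() as sess:
  y_final = sess.run(y)[-1]                     # should be close to exp(-2)
  num_evals = sess.run(info['num_func_evals'])  # number of RHS evaluations
  print(y_final, np.exp(-2.0), num_evals)
```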
@ -300,7 +300,12 @@ This value is ultimately returned as `auc`, an idempotent operation that
computes the area under a discretized curve of precision versus recall values
(computed using the aforementioned variables). The `num_thresholds` variable
controls the degree of discretization with larger numbers of thresholds more
closely approximating the true AUC.
closely approximating the true AUC. The quality of the approximation may vary
dramatically depending on `num_thresholds`.

For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
approximation may be poor if this is not the case.

For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the `auc`.

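For context, a hedged sketch of driving this streaming metric over mini-batches; the placeholder names and tiny feed values are invented for illustration, and it assumes the `tf.contrib.metrics.streaming_auc` endpoint of roughly this TensorFlow vintage:

```python
import tensorflow as tf

labels = tf.placeholder(tf.bool, shape=[None])
predictions = tf.placeholder(tf.float32, shape=[None])

# `auc` reads the current estimate; `update_op` folds a new batch into the
# confusion-matrix counts kept at `num_thresholds` discretization points.
auc, update_op = tf.contrib.metrics.streaming_auc(
    predictions, labels, num_thresholds=200, curve='PR')

with tf.Session() as sess:
  sess.run(tf.initialize_local_variables())
  for preds, labs in [([0.1, 0.9], [False, True]),
                      ([0.4, 0.6], [False, True])]:
    sess.run(update_op, {predictions: preds, labels: labs})
  print(sess.run(auc))
```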
@ -0,0 +1,90 @@
### `tf.contrib.integrate.odeint(func, y0, t, rtol=1e-06, atol=1e-12, method=None, options=None, full_output=False, name=None)` {#odeint}

Integrate a system of ordinary differential equations.

Solves the initial value problem for a non-stiff system of first order ODEs:

```
dy/dt = func(y, t), y(t[0]) = y0
```

where y is a Tensor of any shape.

For example:

```
# solve `dy/dt = -y`, corresponding to exponential decay
tf.contrib.integrate.odeint(lambda y, _: -y, 1.0, [0, 1, 2])
=> [1, exp(-1), exp(-2)]
```

Output dtypes and numerical precision are based on the dtypes of the inputs
`y0` and `t`.

Currently, implements 5th order Runge-Kutta with adaptive step size control
and dense output, using the Dormand-Prince method. Similar to the 'dopri5'
method of `scipy.integrate.ode` and MATLAB's `ode45`.

Based on: Shampine, Lawrence F. (1986), "Some Practical Runge-Kutta Formulas",
Mathematics of Computation, American Mathematical Society, 46 (173): 135-150,
doi:10.2307/2008219

##### Args:


* <b>`func`</b>: Function that maps a Tensor holding the state `y` and a scalar Tensor
    `t` into a Tensor of state derivatives with respect to time.
* <b>`y0`</b>: N-D Tensor giving starting value of `y` at time point `t[0]`. May
    have any floating point or complex dtype.
* <b>`t`</b>: 1-D Tensor holding a sequence of time points for which to solve for
    `y`. The initial time point should be the first element of this sequence,
    and each time must be larger than the previous time. May have any floating
    point dtype. If not provided as a Tensor, converted to a Tensor with
    float64 dtype.
* <b>`rtol`</b>: optional float64 Tensor specifying an upper bound on relative error,
    per element of `y`.
* <b>`atol`</b>: optional float64 Tensor specifying an upper bound on absolute error,
    per element of `y`.
* <b>`method`</b>: optional string indicating the integration method to use. Currently,
    the only valid option is `'dopri5'`.
* <b>`options`</b>: optional dict of configuring options for the indicated integration
    method. Can only be provided if a `method` is explicitly set. For
    `'dopri5'`, valid options include:
    * first_step: an initial guess for the size of the first integration step
      (current default: 1.0, but may later be changed to use heuristics based
      on the gradient).
    * safety: safety factor for adaptive step control, generally a constant
      in the range 0.8-1 (default: 0.9).
    * ifactor: maximum factor by which the adaptive step may be increased
      (default: 10.0).
    * dfactor: maximum factor by which the adaptive step may be decreased
      (default: 0.2).
    * max_num_steps: integer maximum number of integration steps between time
      points in `t` (default: 1000).
* <b>`full_output`</b>: optional boolean. If True, `odeint` returns a tuple
    `(y, info_dict)` describing the integration process.
* <b>`name`</b>: Optional name for this operation.

##### Returns:


* <b>`y`</b>: (N+1)-D tensor, where the first dimension corresponds to different
    time points. Contains the solved value of y for each desired time point in
    `t`, with the initial value `y0` being the first element along the first
    dimension.
* <b>`info_dict`</b>: only if `full_output == True`. A dict with the following values:
    * num_func_evals: integer Tensor counting the number of function
      evaluations.
    * integrate_points: 1D float64 Tensor with the upper bound of each
      integration time step.
    * error_ratio: 1D float Tensor with the estimated ratio of the integration
      error to the error tolerance at each integration step. A ratio greater
      than 1 corresponds to rejected steps.

##### Raises:


* <b>`ValueError`</b>: if an invalid `method` is provided.
* <b>`TypeError`</b>: if `options` is supplied without `method`, or if `t` or `y0` has
    an invalid dtype.

@ -14,7 +14,12 @@ This value is ultimately returned as `auc`, an idempotent operation that
computes the area under a discretized curve of precision versus recall values
(computed using the aforementioned variables). The `num_thresholds` variable
controls the degree of discretization with larger numbers of thresholds more
closely approximating the true AUC.
closely approximating the true AUC. The quality of the approximation may vary
dramatically depending on `num_thresholds`.

For best results, `predictions` should be distributed approximately uniformly
in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
approximation may be poor if this is not the case.

For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the `auc`.

@ -72,7 +72,7 @@ protocol buffer file in the call to `save()`.

- - -

#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=1, pad_step_number=False)` {#Saver.__init__}
#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=2, pad_step_number=False)` {#Saver.__init__}

Creates a `Saver`.

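The signature change above reflects the default checkpoint format moving from V1 to V2. A minimal sketch of pinning the format explicitly when constructing a `Saver` (the variable name is illustrative):

```python
import tensorflow as tf

v = tf.Variable([1.0, 2.0], name="v")

# New default behaviour: V2-format checkpoints.
saver_v2 = tf.train.Saver()

# Explicitly requesting the older V1 format, e.g. for tools that cannot yet
# read V2 checkpoints.
saver_v1 = tf.train.Saver(write_version=tf.train.SaverDef.V1)
```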
@ -913,6 +913,9 @@
  * [`Transformer`](../../api_docs/python/contrib.graph_editor.md#Transformer)
  * [`ts`](../../api_docs/python/contrib.graph_editor.md#ts)

* **[Integrate (contrib)](../../api_docs/python/contrib.integrate.md)**:
  * [`odeint`](../../api_docs/python/contrib.integrate.md#odeint)

* **[Layers (contrib)](../../api_docs/python/contrib.layers.md)**:
  * [`apply_regularization`](../../api_docs/python/contrib.layers.md#apply_regularization)
  * [`avg_pool2d`](../../api_docs/python/contrib.layers.md#avg_pool2d)

@ -1532,7 +1532,7 @@ protocol buffer file in the call to `save()`.

- - -

#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=1, pad_step_number=False)` {#Saver.__init__}
#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None, defer_build=False, allow_empty=False, write_version=2, pad_step_number=False)` {#Saver.__init__}

Creates a `Saver`.

@ -163,8 +163,19 @@ if __name__ == '__main__':
              _GetMatrixUnaryFunctorGradientTest(tf.matrix_inverse,
                                                 dtype, shape))
      setattr(MatrixUnaryFunctorGradientTest,
              'testMatrixUnaryFunctorGradient_' + name,
              _GetMatrixUnaryFunctorGradientTest(tf.matrix_determinant,
                                                 dtype, shape))
              'testMatrixDeterminantGradient_' + name,
              _GetMatrixUnaryFunctorGradientTest(tf.matrix_determinant, dtype,
                                                 shape))

  # Tests for gradients of matrix_solve_ls
  for dtype in np.float32, np.float64:
    for rows in 2, 5, 10:
      for cols in 2, 5, 10:
        for l2_regularization in 0.0, 0.001, 1.0:
          shape = (rows, cols)
          setattr(MatrixBinaryFunctorGradientTest,
                  'testMatrixSolveLsGradient_' + name,
                  _GetMatrixBinaryFunctorGradientTest(tf.matrix_solve_ls, dtype,
                                                      shape))

  tf.test.main()

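The loop above attaches one generated test method per parameter combination to the test class. A standalone sketch of the same `setattr` pattern; the class and helper names here are invented for illustration, not taken from the test file:

```python
import tensorflow as tf


class ExampleParameterizedTest(tf.test.TestCase):
  pass


def _make_test(value):
  # Returns a closure that becomes a named test method on the class.
  def test(self):
    self.assertEqual(value * 2, value + value)
  return test


for value in (1, 2, 3):
  setattr(ExampleParameterizedTest, 'testDoubling_%d' % value,
          _make_test(value))

if __name__ == '__main__':
  tf.test.main()
```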
@ -74,6 +74,92 @@ def _MatrixSolveGrad(op, grad):
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  # a) Output the Cholesky factorization from forward op instead of
  #    recomputing it here.
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs other places in TensorFlow.

  def _overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = op.inputs[2]
    x = op.outputs[0]
    a_shape = array_ops.shape(a)
    batch_shape = a_shape[:-2]
    n = a_shape[-1]

    identity = linalg_ops.eye(n, batch_shape=batch_shape, dtype=a.dtype)
    gramian = math_ops.batch_matmul(
        a, a, adj_x=True) + l2_regularizer * identity
    chol = linalg_ops.cholesky(gramian)
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.batch_matmul(x, z, adj_y=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.batch_matmul(a, zx_sym) + math_ops.batch_matmul(
        b, z, adj_y=True)
    grad_b = math_ops.batch_matmul(a, z)
    return (grad_a, grad_b, None)

  def _underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = op.inputs[2]
    a_shape = array_ops.shape(a)
    batch_shape = a_shape[:-2]
    m = a_shape[-2]

    identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
    gramian = math_ops.batch_matmul(
        a, a, adj_y=True) + l2_regularizer * identity
    chol = linalg_ops.cholesky(gramian)
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.batch_matmul(a, grad))
    # Temporary z = (A * A^T + lambda * I)^{-1} * B.
    z = linalg_ops.cholesky_solve(chol, b)
    bz = -math_ops.batch_matmul(grad_b, z, adj_y=True)
    bz_sym = bz + array_ops.matrix_transpose(bz)
    grad_a = math_ops.batch_matmul(bz_sym, a) + math_ops.batch_matmul(z, grad)
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _overdetermined(op, grad)
    else:
      return _underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _overdetermined(op, grad),
                                 lambda: _underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
@ -129,6 +215,6 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a, 0.5 *
                                       array_ops.matrix_diag_part(grad_a))
    grad_a = array_ops.matrix_set_diag(grad_a,
                                       0.5 * array_ops.matrix_diag_part(grad_a))
    return grad_a

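For reference, the `MatrixSolveLs` gradient added above can be sanity-checked numerically against finite differences. A hedged sketch using only public APIs; the shapes, regularization value, and tolerance are made up for illustration:

```python
import numpy as np
import tensorflow as tf

with tf.Graph().as_default(), tf.Session():
  a = tf.constant(np.random.rand(6, 4))  # m >= n exercises the overdetermined branch
  b = tf.constant(np.random.rand(6, 2))
  x = tf.matrix_solve_ls(a, b, l2_regularizer=0.1, fast=True)
  # Compares the registered analytic gradient with a finite-difference estimate.
  error = tf.test.compute_gradient_error(a, (6, 4), x, (4, 2))
  print(error)  # expected to be small, e.g. below 1e-5
```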
@ -899,7 +899,7 @@ class Saver(object):
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V1,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False):
    """Creates a `Saver`.
